statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,513 @@
1
+ """FISTA with Barzilai-Borwein step sizes and adaptive restart.
2
+
3
+ Uses alternating BB1/BB2 steps (Barzilai & Borwein 1988) that adapt to
4
+ local curvature, eliminating the backtracking line search while preserving
5
+ sparsity. BB1 = <dw,dw>/<dw,dg> (long step), BB2 = <dw,dg>/<dg,dg>
6
+ (short step). Adaptive restart (O'Donoghue & Candes 2015) resets
7
+ momentum when it opposes the descent direction.
8
+
9
+ Supports numpy / cupy / torch backends via auto-detection of X.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ __all__ = ["fista_bb_solver"]
15
+
16
+ import warnings
17
+ import numpy as np
18
+ from statgpu.backends import _resolve_backend, _to_numpy
19
+ from statgpu.backends._utils import _to_float_scalar
20
+ from statgpu.backends._array_ops import (
21
+ _abs_sum_dev, _clip_grad_on_device, _copy_arr, _dot_dev,
22
+ _norm2_dev, _sync_scalars, _zeros,
23
+ )
24
+ from statgpu.penalties._categories import BB_DISABLED as _BB_DISABLED
25
+ from ._convergence import ConvergenceWarning
26
+ from ._constants import (
27
+ _DIVERGE_COEF_NORM_CAP,
28
+ _BB_RESTART_DOT_TOL,
29
+ _DIVERGE_OBJ_RATIO,
30
+ _DIVERGE_OBJ_ABS,
31
+ _GRAD_CLIP_COEF_FACTOR,
32
+ _GRAD_CLIP_ABS_FLOOR,
33
+ _GRAD_CLIP_MAX,
34
+ )
35
+ from ._fista import fista_solver
36
+ from ._utils import (
37
+ _validate_sample_weight,
38
+ _as_backend_vector,
39
+ _call_with_weight,
40
+ _nesterov_update,
41
+ _penalty_name,
42
+ _smooth_penalty_lipschitz,
43
+ _tracking_penalty_value,
44
+ _abs_mean_max,
45
+ )
46
+
47
+
48
def fista_bb_solver(
    loss,
    penalty: "Penalty | None",
    X,
    y,
    max_iter: int = 1000,
    tol: float = 1e-4,
    init_coef=None,
    sample_weight=None,
    use_restart: bool = True,
    step_max_factor: float = 1e3,
    step_min_factor: float = 1e-3,
    bb_burn_in: int = 20,
    cv_mode: bool = False,
    lipschitz_L: float | None = None,
) -> tuple:
    """FISTA with Barzilai-Borwein step sizes and adaptive restart.

    Uses alternating BB1/BB2 steps (Barzilai & Borwein 1988) that adapt to
    local curvature, eliminating the backtracking line search while preserving
    sparsity.  BB1 = <dw,dw>/<dw,dg> (long step), BB2 = <dw,dg>/<dg,dg>
    (short step).  Adaptive restart (O'Donoghue & Candes 2015) resets
    momentum when it opposes the descent direction.

    Supports numpy / cupy / torch backends via auto-detection of X.

    Parameters
    ----------
    loss
        Object providing ``preprocess(X, y)``, ``lipschitz``, ``gradient``
        and ``value``.  The following optional attributes are read with
        ``getattr`` and tune the iteration: ``_prefer_fista_over_bb``,
        ``_is_quadratic``, ``_lipschitz_uses_y``, ``_lipschitz_safety``,
        ``_inverse_gaussian``, ``_tweedie``, ``_poisson_like``,
        ``_gamma_like``, ``_momentum_disabled``, ``_momentum_beta_cap``,
        ``_steep_loss`` and ``name``.
    penalty
        Object providing ``proximal(w, step, backend=...)``; its optional
        ``name`` attribute selects the burn-in policy.
        NOTE(review): annotated as optional, but ``penalty.proximal`` is
        called unconditionally -- confirm whether ``None`` is supported.
    X, y
        Design matrix and response, passed through ``loss.preprocess``.
    max_iter
        Maximum number of iterations.
    tol
        The loop stops once ``sum(|coef - coef_old|) < tol``.
    init_coef
        Optional warm start; zeros are used when omitted.
    sample_weight
        Optional weights, forwarded to every loss call through
        ``_call_with_weight``.
    use_restart
        Enable the adaptive momentum restart test.
    step_max_factor, step_min_factor
        BB steps are clamped to ``[step_L * step_min_factor,
        step_L * step_max_factor]`` where ``step_L = 1 / L``.
    bb_burn_in
        Number of initial iterations that use the fixed ``1 / L`` step.
        The value is overridden for some penalty names and loss families
        (see the body).
    cv_mode
        Together with a GPU backend this selects on-device gradient
        clipping and sparser divergence / Lipschitz / convergence checks.
    lipschitz_L
        Optional pre-computed Lipschitz constant; used only when it
        converts to a positive float.

    Returns
    -------
    tuple
        ``(coef, n_iter)``.  When the smooth-logistic shortcut applies, the
        return value of ``fista_solver`` is passed through unchanged.

    Warns
    -----
    ConvergenceWarning
        When the iteration budget is exhausted without the tolerance test
        succeeding.
    """
    backend = _resolve_backend("auto", X)
    _is_gpu = backend in ("torch", "cupy")
    X_proc, y_proc = loss.preprocess(X, y)
    n_features = X_proc.shape[1]
    _pen_name = _penalty_name(penalty)

    # Smooth logistic objectives are better handled by the Armijo-backed FISTA
    # path.  This keeps explicit fista_bb numerically aligned across CPU/CuPy/
    # Torch for logistic+none/l2 Section A checks.
    if (
        getattr(loss, '_prefer_fista_over_bb', False)
        and _pen_name in ("l2", "none", "null", "")
    ):
        return fista_solver(
            loss,
            penalty,
            X,
            y,
            max_iter=max_iter,
            tol=tol,
            init_coef=init_coef,
            sample_weight=sample_weight,
            cv_mode=cv_mode,
        )

    # Validate the weights before they are first used (the Lipschitz
    # estimate below already consumes them).
    _validate_sample_weight(sample_weight, X_proc.shape[0])

    # --- Initialize coefficients ---
    if init_coef is not None:
        coef = _as_backend_vector(init_coef, backend, X)
    else:
        coef = _zeros(n_features, backend, ref_tensor=X)

    y_k = _copy_arr(coef)
    t_k = 1.0

    # Divergence detection: track best objective for recovery
    _obj_best = float('inf')
    _coef_best = None
    _diverge_count = 0

    _bb_use_long = True  # alternate BB1 / BB2
    dot_dw_dg = 0.0  # BB step numerator (initialized for bb_burn_in=0)
    dot_dw_dw = 1.0  # BB step denominator
    _div_check_interval = 25 if cv_mode and _is_gpu else 5
    _lip_check_interval = 25 if cv_mode and _is_gpu else 5
    _conv_check_interval = 10 if cv_mode and _is_gpu else 3

    # For quadratic losses (squared_error) the gradient is linear in coef,
    # so dg = H @ dw and BB1 = BB2 = 1 / Rayleigh_quotient(H, dw).  The BB
    # step gives zero adaptation and the algorithm degenerates to ISTA
    # (O(1/k) convergence), too slow to reach the true sparse solution
    # within max_iter.  Use standard FISTA (fixed Lipschitz step + Nesterov
    # momentum, O(1/k^2)) instead.
    _is_quadratic = getattr(loss, '_is_quadratic', False)

    # BB steps estimate local curvature from smooth-gradient differences.
    # For non-smooth penalties the proximal operator introduces a
    # discontinuity that makes the gradient differences noisy.
    #
    # For GLM losses with convex non-smooth penalties (L1, elasticnet,
    # adaptive_l1) the subgradient is bounded and BB differences are valid
    # after a burn-in that lets the iterates stabilise.
    #
    # For non-convex non-smooth penalties (SCAD, MCP, group_*) the
    # subgradient can change abruptly (reweighting, folding points),
    # amplifying noise through the non-linear link and causing catastrophic
    # divergence.  Disable BB entirely for these.
    _raw_pen_name = getattr(penalty, "name", _pen_name)
    if hasattr(_raw_pen_name, 'lower'):
        _pen_name = _raw_pen_name.lower()
    if _pen_name in _BB_DISABLED:
        bb_burn_in = max_iter + 1  # never switch to BB
    elif _pen_name in {"l1", "elasticnet", "en", "adaptive_l1", "adaptive_lasso"}:
        bb_burn_in = max(bb_burn_in, 50)  # longer burn-in for non-smooth

    # Initial Lipschitz at zero (safe for all losses).  Computing L at
    # init_coef can produce enormous values for exp-link families (mu =
    # exp(X@coef) explodes for warm-start coefs from OLS).
    _zero_coef_bb = _zeros(n_features, backend, ref_tensor=X)
    _cached_lipschitz_L = None
    if lipschitz_L is not None:
        try:
            _cached_lipschitz_L = float(_to_numpy(lipschitz_L))
        except (ValueError, TypeError):
            _cached_lipschitz_L = None
    if _cached_lipschitz_L is not None and _cached_lipschitz_L > 0:
        L = _cached_lipschitz_L
    else:
        _cached_lipschitz_L = None
        L = _call_with_weight(
            loss.lipschitz, X_proc, _zero_coef_bb,
            y=y_proc, sample_weight=sample_weight,
        )
    if L <= 0:
        L = 1.0

    # For GLM losses with exp link (Poisson, etc.), mu at coef=0 is ~1, but
    # mu near the optimum ~ y.  The Hessian X'@diag(mu)@X scales linearly
    # with mu, so Lipschitz at init can underestimate the true curvature by
    # orders of magnitude.  Use a geometric-mean heuristic: robust against
    # extreme outliers while still scaling up enough to avoid oversized
    # first steps.  Losses flagged with `_lipschitz_uses_y` skip this.
    _skip_y_scaling_bb = getattr(loss, '_lipschitz_uses_y', False)
    _y_scale = 1.0  # default; overridden below for families that need it
    if not _is_quadratic and not _skip_y_scaling_bb:
        _y_mean, _y_max = _abs_mean_max(y_proc, backend)
        _y_scale = max(1.0, _y_mean, np.sqrt(_y_mean * _y_max))
    if _y_scale > 1.0:
        L = L * _y_scale

    # Per-family safety factor (e.g. inverse Gaussian, whose gradient scales
    # as 1/mu^3 and is extremely sensitive to the step size).
    _invgauss_like = getattr(loss, '_inverse_gaussian', False)
    _tweedie_like = getattr(loss, '_tweedie', False)
    _lip_safety_bb = getattr(loss, '_lipschitz_safety', 1.0)
    if _lip_safety_bb > 1.0:
        L = L * _lip_safety_bb

    # Add smooth penalty Lipschitz contribution (e.g. l2 gradient alpha*coef
    # has Lipschitz alpha).  Without this the step 1/L is too large.
    _smooth_lip_bb = _smooth_penalty_lipschitz(penalty)
    if _smooth_lip_bb > 0:
        L = L + _smooth_lip_bb

    step_L = 1.0 / L
    step_k = step_L
    step_max = step_L * step_max_factor
    step_min = step_L * step_min_factor

    # Gradient at initial point for first BB difference
    grad_old = _call_with_weight(
        loss.gradient, X_proc, y_proc, coef, sample_weight=sample_weight)
    # Initialize dg for BB step selection (used before first assignment in loop)
    dg = _zeros(n_features, backend, ref_tensor=X_proc)
    iteration = -1  # default if max_iter=0

    # Loop-invariant constants for momentum/BB decisions
    _poisson_like = getattr(loss, '_poisson_like', False)
    _gamma_like = getattr(loss, '_gamma_like', False)
    _steep_loss = getattr(loss, '_steep_loss', False)

    # --- Pre-compute loop-invariant burn-in and momentum parameters ---
    # These depend only on loss/penalty type, not on iterates.
    if _invgauss_like:
        bb_burn_in = max_iter + 1  # never switch to BB
    elif _tweedie_like:
        bb_burn_in = max(200, max_iter // 2)
    elif _gamma_like:
        bb_burn_in = max(50, max_iter // 8)

    if getattr(loss, '_momentum_disabled', False):
        _momentum_burn_in = max_iter + 1  # never use momentum
    elif _tweedie_like:
        _momentum_burn_in = max(100, max_iter // 4)
    elif _gamma_like:
        _momentum_burn_in = max(30, max_iter // 10)
    else:
        _momentum_burn_in = 0  # momentum from the start

    # Conservative momentum for specific loss+penalty combos
    _momentum_beta_cap = getattr(loss, '_momentum_beta_cap', None)
    if _momentum_beta_cap is not None and _poisson_like and not _invgauss_like:
        if getattr(penalty, 'name', '') in ("l2", "none", "", None):
            _momentum_burn_in = min(100, max_iter)
    if (_tweedie_like or _gamma_like) and _momentum_beta_cap is None:
        _momentum_beta_cap = 0.2

    # Set only when the tolerance test succeeds, so that converging on the
    # very last allowed iteration is not reported as a failure.
    _converged = False

    for iteration in range(max_iter):
        coef_old = _copy_arr(coef)

        # Gradient at extrapolated point
        grad = _call_with_weight(
            loss.gradient, X_proc, y_proc, y_k, sample_weight=sample_weight)

        # Clip extreme gradients -- every iteration, all backends.
        # Skip for inverse_gaussian: 1/mu^3 gradient scaling produces large
        # but valid gradients; clipping prevents convergence to the true
        # optimum.
        if not _invgauss_like:
            if cv_mode and _is_gpu:
                grad = _clip_grad_on_device(grad, coef_old, backend)
            else:
                _gn_f, _coef_abs_f = _sync_scalars(
                    _norm2_dev(grad), _abs_sum_dev(coef_old), backend=backend)
                _gmax = max(
                    _coef_abs_f * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR,
                    _GRAD_CLIP_MAX,
                )
                if _gn_f > _gmax:
                    grad = grad * (_gmax / _gn_f)

        # --- Divergence detection ---
        # CPU: checked on every iteration after the first.
        # GPU: deferred to every `_div_check_interval` iterations (plus the
        # first few) to avoid a host sync per iteration.
        _do_full_div_check = (
            iteration % _div_check_interval == 0 or iteration <= 5
        )
        _do_div_check = (
            not _is_quadratic
            and iteration > 0
            and (not _is_gpu or _do_full_div_check)
        )
        if _do_div_check:
            _diverged = False
            # Coef norm divergence check (works for both CPU and GPU)
            if iteration > 10:
                if _to_float_scalar(_norm2_dev(coef)) > _DIVERGE_COEF_NORM_CAP:
                    _diverged = True
            # Full objective check
            if not _diverged:
                _obj_val = float(_to_numpy(_call_with_weight(
                    loss.value, X_proc, y_proc, coef,
                    sample_weight=sample_weight)))
                _pen_val = _tracking_penalty_value(penalty, coef)
                _obj_total = _obj_val + _pen_val
                if not np.isfinite(_obj_total):
                    _diverged = True
                elif not np.isfinite(_obj_best):
                    # _obj_best is inf/-inf (first valid iter or degenerate
                    # loss): skip ratio-based check, rely on norm check above.
                    pass
                elif _obj_best > 1e-8:
                    _diverge_threshold = _obj_best * 10.0 + 1e-8
                    if _invgauss_like or _tweedie_like:
                        _diverge_threshold = (
                            _obj_best * _DIVERGE_OBJ_RATIO + _DIVERGE_OBJ_ABS)
                    _diverged = _obj_total > _diverge_threshold
                else:
                    _diverge_threshold = (
                        _obj_best + max(abs(_obj_best) * 10.0, 1.0))
                    if _invgauss_like or _tweedie_like:
                        _diverge_threshold = _obj_best + max(
                            abs(_obj_best) * _DIVERGE_OBJ_RATIO,
                            _DIVERGE_OBJ_ABS,
                        )
                    _diverged = _obj_total > _diverge_threshold
            if _diverged:
                # Diverged: reset to best known iterate (or zeros) and halve
                # the step.
                _diverge_count += 1
                if _coef_best is not None:
                    coef = _copy_arr(_coef_best)
                else:
                    # No valid iterate yet -- reset to zeros
                    coef = _zeros(n_features, backend, ref_tensor=X_proc)
                y_k = _copy_arr(coef)
                t_k = 1.0
                grad_old = _call_with_weight(
                    loss.gradient, X_proc, y_proc, coef,
                    sample_weight=sample_weight)
                # Halve step size bounds
                step_L = step_L * 0.5
                step_k = step_L
                step_max = step_max * 0.5
                step_min = step_min * 0.5
                L = L * 2.0
                # Reset BB state
                dot_dw_dg = 0.0
                dot_dw_dw = 1.0
                continue
            elif _obj_total < _obj_best:
                _obj_best = _obj_total
                _coef_best = _copy_arr(coef)

        # --- Step size selection ---
        if _is_quadratic or iteration < bb_burn_in:
            # Quadratic loss or burn-in phase: use fixed Lipschitz step.
            # During burn-in for GLM losses, BB steps are delayed because
            # early gradient differences (dw, dg) are dominated by the
            # coef trajectory from zero toward the optimum rather than by
            # local curvature; using BB too early amplifies oscillations.
            step_k = step_L
            # Refresh the Lipschitz-based step periodically during burn-in.
            if (
                not _is_quadratic
                and iteration > 0
                and iteration % _lip_check_interval == 0
            ):
                # Use global Lipschitz (coef=zero) during burn-in to prevent
                # iterate-dependent Lipschitz from shrinking too fast.
                # Pass zero coef -- not all losses handle coef=None.
                if _cached_lipschitz_L is not None:
                    L_new = _cached_lipschitz_L
                else:
                    # Forward the sample weights exactly as the initial
                    # estimate does, so both estimates are comparable.
                    L_new = _call_with_weight(
                        loss.lipschitz, X_proc, _zero_coef_bb,
                        y=y_proc, sample_weight=sample_weight,
                    )
                if L_new > 0:
                    # Re-apply y-scaling and per-family safety factor
                    if _y_scale > 1.0:
                        L_new = L_new * _y_scale
                    if _lip_safety_bb > 1.0:
                        L_new = L_new * _lip_safety_bb
                    # Allow L to move toward L_new: full increase, gradual
                    # decrease.
                    if L_new > L:
                        L = L_new
                    else:
                        L = max(L * 0.8, L_new)
                    step_L = 1.0 / L
                    step_k = step_L
                    step_max = step_L * step_max_factor
                    step_min = step_L * step_min_factor
        else:
            # Nonlinear GLM loss, post-burn-in: use BB step when valid,
            # otherwise keep the previous step_k (step_L or last valid BB).
            if dot_dw_dg > _BB_RESTART_DOT_TOL:
                if _bb_use_long:
                    step_k = dot_dw_dw / dot_dw_dg  # BB1: long
                else:
                    dot_dg_dg = float(_to_numpy(_dot_dev(dg, dg)))
                    step_k = dot_dw_dg / max(dot_dg_dg, 1e-14)  # BB2: short
                _bb_use_long = not _bb_use_long
                # Tweedie: cap BB step more aggressively to prevent overshoot
                if _tweedie_like:
                    step_k = min(step_k, step_L * 2.0)
                step_k = min(max(step_k, step_min), step_max)

        # Gradient step + proximal
        w_tilde = y_k - step_k * grad
        coef = penalty.proximal(w_tilde, step_k, backend=backend)

        # Safeguarded backtracking for GLM losses:
        # After proximal, verify the objective didn't explode.  If it did,
        # halve step and recompute.  This catches cases where the BB step
        # or Lipschitz estimate was too optimistic for the new coef region.
        # The objective is only evaluated every 5 iterations (and during
        # the first few) because loss.value() is expensive.
        _last_coef_norm_f = None
        if not _is_quadratic and (iteration % 5 == 0 or iteration <= 5):
            for _bt in range(15):
                # Batch obj + coef-norm into a single sync.  The weights are
                # forwarded so the value is comparable with `_obj_best`,
                # which is computed with weights in the divergence check.
                _new_obj, _new_norm = _sync_scalars(
                    _call_with_weight(
                        loss.value, X_proc, y_proc, coef,
                        sample_weight=sample_weight),
                    _norm2_dev(coef),
                    backend=backend,
                )
                _new_total = _new_obj + _tracking_penalty_value(penalty, coef)
                if _steep_loss:
                    # Relative cap (10x best objective, at least 1e6): large
                    # counts can legitimately produce a loss above 1e6.
                    if np.isfinite(_obj_best):
                        _obj_limit = max(_obj_best * 10.0, 1e6)
                    else:
                        _obj_limit = 1e6
                else:
                    # Accept unless significantly worse than the best known
                    # objective.
                    _obj_limit = max(_obj_best * 1.5 + 1.0, 1e3)
                _obj_acceptable = (
                    np.isfinite(_new_total)
                    and _new_norm < _DIVERGE_COEF_NORM_CAP
                    and _new_total < _obj_limit
                )
                if _obj_acceptable:
                    _last_coef_norm_f = _new_norm
                    break
                # Step too large -- halve and retry
                step_k = step_k * 0.5
                L = L * 2.0
                w_tilde = y_k - step_k * grad
                coef = penalty.proximal(w_tilde, step_k, backend=backend)

        # Finiteness check: if coef is non-finite after proximal, reset.
        # Reuse the norm already synchronized by safeguarded backtracking.
        if not _is_quadratic:
            if _last_coef_norm_f is not None:
                _finite_ok2 = np.isfinite(_last_coef_norm_f)
            else:
                _finite_ok2 = np.isfinite(_to_float_scalar(_norm2_dev(coef)))
            if not _finite_ok2:
                _diverge_count += 1
                if _coef_best is not None:
                    coef = _copy_arr(_coef_best)
                else:
                    # No valid iterate yet: fall back to zeros (same policy
                    # as the divergence branch) rather than carrying a
                    # non-finite iterate into the next gradient evaluation.
                    coef = _zeros(n_features, backend, ref_tensor=X_proc)
                y_k = _copy_arr(coef)
                t_k = 1.0
                grad_old = _call_with_weight(
                    loss.gradient, X_proc, y_proc, coef,
                    sample_weight=sample_weight)
                step_L = step_L * 0.5
                step_k = step_L
                step_max = step_max * 0.5
                step_min = step_min * 0.5
                L = L * 2.0
                dot_dw_dg = 0.0
                dot_dw_dw = 1.0
                continue

        # --- Store BB step info for next iteration (non-quadratic only) ---
        # Uses the accepted iterate, i.e. after any backtracking.
        if not _is_quadratic:
            grad_new = _call_with_weight(
                loss.gradient, X_proc, y_proc, coef,
                sample_weight=sample_weight)
            dw = coef - coef_old
            dg = grad_new - grad_old
            # Batch two dot products into a single GPU->CPU sync.
            dot_dw_dw, dot_dw_dg = _sync_scalars(
                _dot_dev(dw, dw), _dot_dev(dw, dg), backend=backend)
            grad_old = grad_new

        # --- Nesterov momentum with adaptive restart ---
        # bb_burn_in, _momentum_burn_in, _momentum_beta_cap are loop-invariant
        # and computed once before the loop.
        if iteration < _momentum_burn_in:
            # No momentum yet: next gradient at current point.
            t_k = 1.0
            y_k = _copy_arr(coef)
        elif _momentum_beta_cap is not None:
            # Conservative momentum: fixed small beta to avoid explosion
            y_k = coef + _momentum_beta_cap * (coef - coef_old)
            t_k = 1.0
        else:
            y_k, t_new = _nesterov_update(coef, coef_old, t_k)

            if use_restart and iteration > 0:
                # NOTE(review): y_k here is already the freshly extrapolated
                # point.  If _nesterov_update returns
                # coef + beta * (coef - coef_old), this dot product equals
                # beta * ||coef - coef_old||^2 and the restart would fire
                # whenever beta > 0 -- confirm this is intended rather than
                # testing against the previous extrapolation point.
                _mc_dev = _dot_dev(y_k - coef, coef - coef_old)
                if _to_float_scalar(_mc_dev) > 0:
                    t_new = 1.0
                    y_k = coef + 0.0 * (coef - coef_old)

            t_k = t_new

        # --- Convergence check -- deferred for GPU, every iteration for CPU.
        if (
            not _is_gpu
            or iteration < 20
            or iteration % _conv_check_interval == 0
        ):
            if _to_float_scalar(_abs_sum_dev(coef - coef_old)) < tol:
                _converged = True
                break

    # Return best iterate if divergence was detected
    if _diverge_count > 0 and _coef_best is not None:
        coef = _copy_arr(_coef_best)

    n_iter = iteration + 1
    if not _converged:
        warnings.warn(
            f"fista_bb_solver did not converge within {max_iter} iterations "
            f"(loss={getattr(loss, 'name', '?')}, penalty={getattr(penalty, 'name', '?')}). "
            f"Consider increasing max_iter or using a different solver (newton, lbfgs, irls).",
            ConvergenceWarning,
            stacklevel=2,
        )
    return coef, n_iter