statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,436 @@
1
+ """FISTA solver with backtracking line search.
2
+
3
+ minimize: loss(X, y, w) + penalty(w)
4
+
5
+ Supports numpy / cupy / torch backends via auto-detection.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ __all__ = ["fista_solver"]
11
+
12
+ import warnings
13
+ import numpy as np
14
+ from statgpu.backends import _resolve_backend, _to_numpy
15
+ from statgpu.backends._utils import _to_float_scalar, _get_xp
16
+ from statgpu.backends._array_ops import (
17
+ _abs_sum_dev,
18
+ _clip_grad_on_device,
19
+ _copy_arr,
20
+ _dot_dev,
21
+ _norm2_dev,
22
+ _sum_sq_dev,
23
+ _sync_scalars,
24
+ _zeros,
25
+ )
26
+ from ._convergence import ConvergenceWarning
27
+ from ._constants import (
28
+ _SLACK_TOLERANCE,
29
+ _DIVERGE_COEF_NORM_CAP,
30
+ _LIPSCHITZ_SAFETY_LOGISTIC_CV,
31
+ _GRAD_CLIP_COEF_FACTOR,
32
+ _GRAD_CLIP_ABS_FLOOR,
33
+ _GRAD_CLIP_MAX,
34
+ )
35
+ from ._utils import (
36
+ _validate_sample_weight,
37
+ _as_backend_vector,
38
+ _call_with_weight,
39
+ _nesterov_update,
40
+ _penalty_name,
41
+ _smooth_penalty_lipschitz,
42
+ _abs_mean_max,
43
+ _tracking_penalty_value,
44
+ )
45
+
46
+
47
def _restart_point(best_coef, n_features, backend, ref_tensor):
    """Return the point FISTA restarts from after a non-finite or diverged step.

    A copy of the best iterate recorded so far when one exists, otherwise a
    zero vector on ``backend`` (allocated like ``ref_tensor``).
    """
    if best_coef is not None:
        return _copy_arr(best_coef)
    return _zeros(n_features, backend, ref_tensor=ref_tensor)


def _relax_lipschitz(L, L_new):
    """Combine the current step-size constant ``L`` with a fresh estimate.

    Increases are adopted immediately (a too-small ``L`` risks divergence);
    decreases are damped to at most 20% per update.
    """
    if L_new > L:
        return L_new
    return max(L * 0.8, L_new)


def fista_solver(
    loss: "GLMLoss",
    penalty: "Penalty | None",
    X,
    y,
    max_iter: int = 1000,
    tol: float = 1e-4,
    init_coef=None,
    sample_weight=None,
    lipschitz_L: float | None = None,
    cv_mode: bool = False,
) -> tuple:
    """General FISTA solver with backtracking line search.

    Minimizes ``loss(X, y, w) + penalty(w)``. Supports numpy / cupy / torch
    backends via auto-detection of X.

    Two update paths are used:

    * **Async GPU path** -- torch/cupy backend, non-smooth penalty, and either
      ``cv_mode`` or a quadratic loss (unless the loss sets
      ``_gpu_loop_excluded`` outside CV mode). Fixed step ``1/L`` with no
      backtracking; finiteness checks, best-iterate tracking and Lipschitz
      refreshes run only periodically to limit device-to-host syncs.
    * **Backtracking path** -- everything else. Per-iteration gradient
      clipping and a sufficient-decrease test (``L *= 1.5``, at most 20
      times), plus per-iteration finiteness / divergence checks for
      non-quadratic losses.

    Parameters
    ----------
    loss : GLMLoss
        GLM loss function with gradient(), lipschitz(), preprocess(), value()
        and fused_value_and_gradient(). Optional class attributes tune the
        solver: ``_is_quadratic``, ``_momentum_beta_cap``, ``_skip_momentum``,
        ``_lipschitz_at_init``, ``_lipschitz_uses_y``, ``_lipschitz_safety``,
        ``_lipschitz_safety_cv`` and ``_gpu_loop_excluded``.
    penalty : Penalty or None
        Penalty with proximal(). ``None`` means no penalty (identity
        proximal step).
    X : array
        Design matrix (numpy/cupy/torch).
    y : array
        Target (numpy/cupy/torch).
    max_iter : int
        Maximum iterations.
    tol : float
        Convergence tolerance on the summed absolute change of the
        coefficients between consecutive iterations.
    init_coef : array, optional
        Initial coefficient vector.
    sample_weight : array, optional
        Per-sample weights. Non-uniform weights are currently rejected in this
        solver path to avoid silently running an incorrect unweighted update.
    lipschitz_L : float, optional
        Precomputed Lipschitz constant of the loss gradient. When given and
        positive it replaces the loss-based estimate; the smooth-penalty,
        y-scaling and safety-factor adjustments are still applied on top.
    cv_mode : bool, default=False
        Private CV fast path: keeps the same update rule but checks objective
        and convergence less often on GPU non-smooth GLM paths.

    Returns
    -------
    coef : array
        Fitted coefficients (same backend as X): the best tracked iterate
        when one was recorded, otherwise the final iterate.
    n_iter : int
        Number of iterations.

    Warns
    -----
    ConvergenceWarning
        If the convergence criterion was not met within ``max_iter``.
    """
    backend = _resolve_backend("auto", X)
    X_proc, y_proc = loss.preprocess(X, y)
    n_samples, n_features = X_proc.shape[0], X_proc.shape[1]
    # Validate weights before they feed the weighted Lipschitz estimate below,
    # where a malformed vector would otherwise surface as an opaque
    # broadcasting error (or cost an X'WX product before being rejected).
    _validate_sample_weight(sample_weight, n_samples)

    _is_quadratic = getattr(loss, '_is_quadratic', False)
    # Momentum control via loss class attributes:
    #   _momentum_beta_cap: if set, cap Nesterov beta at this value
    #   _skip_momentum: if True, disable momentum entirely
    # Conservative momentum (cap beta at 0.5) for exp-link families and
    # for logistic/gamma with non-smooth penalties. Logistic/gamma with
    # smooth penalties (none, l2) benefit from full Nesterov acceleration.
    _momentum_beta_cap = getattr(loss, '_momentum_beta_cap', None)
    _skip_momentum = getattr(loss, '_skip_momentum', False)
    _lip_safety = getattr(loss, '_lipschitz_safety', 1.0)

    def _prox(w, step):
        # No penalty: the proximal operator of the zero function is identity.
        if penalty is None:
            return w
        return penalty.proximal(w, step, backend=backend)

    if init_coef is not None:
        coef = _as_backend_vector(init_coef, backend, X)
    else:
        coef = _zeros(n_features, backend, ref_tensor=X)
    y_k = _copy_arr(coef)
    t_k = 1.0

    # Divergence safeguards: best objective (loss + tracked penalty) seen so
    # far and the iterate that achieved it, used for recovery and return.
    _obj_best_fista = float('inf')
    _coef_best_fista = None

    # ---- Initial Lipschitz constant -----------------------------------------
    # Default evaluation point is zero (safe for exp-link warm starts), but a
    # loss may request evaluation at the provided init to avoid degenerate
    # curvature from eta=0 clipping.
    _cached_XtWX_weighted = None  # reused by the async GPU loop's refreshes
    if lipschitz_L is not None and lipschitz_L > 0:
        L = lipschitz_L
    else:
        if sample_weight is not None:
            # Weighted estimate: largest diagonal entry of X' diag(w) X / sum(w).
            # NOTE(review): max(diag) is a lower bound on the largest
            # eigenvalue, not an upper bound; the no-backtracking path relies
            # on the safety factors applied below.
            _xp_mod = _get_xp(backend)
            _sw = _xp_mod.asarray(_to_numpy(sample_weight), dtype=X_proc.dtype)
            sw_sum = _to_float_scalar(_xp_mod.sum(_sw))
            sw_col = _sw[:, None] if _sw.ndim == 1 else _sw
            XtWX = X_proc.T @ (X_proc * sw_col) / sw_sum
            L = _to_float_scalar(_xp_mod.max(_xp_mod.diag(XtWX)))
            # X and the weights are constant, so the matrix can be cached.
            _cached_XtWX_weighted = XtWX
        else:
            if getattr(loss, '_lipschitz_at_init', False):
                _lip_coef = _copy_arr(coef)
            else:
                _lip_coef = _zeros(n_features, backend, ref_tensor=X)
            L = loss.lipschitz(X_proc, _lip_coef, y=y_proc)
        if L <= 0:
            L = 1.0

    # NOTE(review): these adjustments run at function level, so they also
    # scale a caller-supplied ``lipschitz_L``; ``_smooth_lip`` and ``_y_scale``
    # must be bound on every path because the in-loop refreshes read them.
    #
    # Smooth penalty contribution (e.g. the l2 gradient alpha*coef has
    # Lipschitz constant alpha). Without it the step 1/L is too large,
    # causing oscillation near the optimum.
    _smooth_lip = _smooth_penalty_lipschitz(penalty)
    if _smooth_lip > 0:
        L = L + _smooth_lip
    # Non-quadratic GLM losses: curvature at coef=0 (mu ~ 1) can underestimate
    # curvature near the optimum (mu ~ y), so scale L up by a geometric-mean
    # factor of y unless the loss's lipschitz() already accounts for y.
    _y_scale = 1.0
    if not _is_quadratic and not getattr(loss, '_lipschitz_uses_y', False):
        _y_mean, _y_max = _abs_mean_max(y_proc, backend)
        _y_scale = max(1.0, _y_mean, np.sqrt(_y_mean * _y_max))
        if _y_scale > 1.0:
            L = L * _y_scale
    # Loss-specific safety factors (class attributes), plus extra in CV mode.
    if _lip_safety > 1.0:
        L = L * _lip_safety
    _lip_safety_cv = getattr(
        loss,
        '_lipschitz_safety_cv',
        _LIPSCHITZ_SAFETY_LOGISTIC_CV if cv_mode else 1.0,
    )
    if cv_mode and _lip_safety_cv > 1.0:
        L = L * _lip_safety_cv

    # ---- Update-path selection ----------------------------------------------
    # Async GPU loop (fixed step, no backtracking, deferred checks) for
    # non-smooth penalties (l1, elasticnet, scad, mcp, adaptive, group) when
    # the loss is quadratic (Lipschitz is exact) or in CV mode (GLM losses,
    # with the extra CV safety factor). Smooth penalties (l2, none) always
    # use backtracking.
    _non_smooth = _penalty_name(penalty) not in ("none", "null", "l2", "")
    _gpu_excluded = getattr(loss, '_gpu_loop_excluded', False) and not cv_mode
    _is_gpu = backend in ("torch", "cupy")
    _use_gpu_loop = (
        _is_gpu
        and _non_smooth
        and (cv_mode or _is_quadratic)
        and not _gpu_excluded
    )
    if _use_gpu_loop:
        _conv_interval, _div_interval, _lip_interval = 10, 25, 25
    else:
        _conv_interval, _div_interval, _lip_interval = 3, 5, 5

    # Squared error on the async GPU path: precompute X'X/n and X'y/n so the
    # gradient is a single p-by-p matrix-vector product per iteration.
    _use_xtx = _is_quadratic and sample_weight is None and _use_gpu_loop
    if _use_xtx:
        XtX = X_proc.T @ X_proc / n_samples
        Xty = X_proc.T @ y_proc / n_samples
    else:
        XtX = Xty = None

    iteration = -1  # gives n_iter == 0 when the loop body never runs
    converged = False
    for iteration in range(max_iter):
        coef_old = _copy_arr(coef)

        # Gradient at the extrapolated point y_k, plus the loss value there
        # (only the backtracking test reads it).
        if _use_xtx:
            # grad = X'X/n @ w - X'y/n = X'(Xw - y)/n
            grad = XtX @ y_k - Xty
            q_yk_dev = None  # unused: the async path never backtracks
        elif sample_weight is not None:
            q_yk_dev, grad = loss.fused_value_and_gradient(
                X_proc, y_proc, y_k, sample_weight=sample_weight
            )
        else:
            q_yk_dev, grad = loss.fused_value_and_gradient(X_proc, y_proc, y_k)

        if _use_gpu_loop:
            # -- Async GPU path: all ops stay on device ----------------------
            grad = _clip_grad_on_device(grad, coef_old, backend)
            step = 1.0 / L
            # Single proximal step -- no backtracking (L is conservative enough).
            coef = _prox(y_k - step * grad, step)

            # Deferred safety checks: one device->host transfer per check,
            # every iteration up to 20, then every _div_interval iterations.
            if iteration > 0 and (iteration < 20 or iteration % _div_interval == 0):
                _obj_val_f = float(_to_numpy(loss.value(X_proc, y_proc, coef)))
                if not np.isfinite(_obj_val_f):
                    coef = _restart_point(_coef_best_fista, n_features, backend, X_proc)
                    y_k = _copy_arr(coef)
                    t_k = 1.0
                    L = L * 2.0
                    continue
                _obj_val_f += _tracking_penalty_value(penalty, coef)
                if _obj_val_f < _obj_best_fista:
                    _obj_best_fista = _obj_val_f
                    _coef_best_fista = _copy_arr(coef)
                # Periodic Lipschitz refresh, piggybacking on the same sync.
                # Skipped for quadratic losses (Lipschitz is constant).
                if not _is_quadratic and iteration % _lip_interval == 0:
                    if sample_weight is not None and _cached_XtWX_weighted is not None:
                        _xp_lip = _get_xp(backend)
                        L_new = _to_float_scalar(
                            _xp_lip.max(_xp_lip.diag(_cached_XtWX_weighted))
                        )
                    else:
                        L_new = loss.lipschitz(X_proc, coef, y=y_proc)
                    if L_new > 0:
                        # Re-apply the init-time y-scaling, which the
                        # iterate-dependent estimate may not capture.
                        if _y_scale > 1.0:
                            L_new = L_new * _y_scale
                        L_new *= _lip_safety
                        if _smooth_lip > 0:
                            L_new = L_new + _smooth_lip
                        L = _relax_lipschitz(L, L_new)
        else:
            # -- Backtracking path (CPU, or GPU with smooth penalties) -------
            # Norm-based gradient clipping. It syncs, but backtracking syncs
            # every iteration anyway, so on-device clipping would not help.
            _gn_f, _coef_abs_f = _sync_scalars(
                _norm2_dev(grad), _abs_sum_dev(coef_old), backend=backend
            )
            _gmax = max(
                _coef_abs_f * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR,
                _GRAD_CLIP_MAX,
            )
            if _gn_f > _gmax:
                grad = grad * (_gmax / _gn_f)

            # Sufficient-decrease test against the quadratic upper model
            #   q(y_k) + <grad, d> + (L/2) ||d||^2,   d = coef_new - y_k;
            # on failure grow L by 1.5x, at most 20 times.
            step = 1.0 / L
            for _bt in range(20):
                coef_new = _prox(y_k - step * grad, step)
                diff = coef_new - y_k
                if sample_weight is not None:
                    q_new_dev, _ = loss.fused_value_and_gradient(
                        X_proc, y_proc, coef_new, sample_weight=sample_weight
                    )
                else:
                    q_new_dev = loss.value(X_proc, y_proc, coef_new)
                bound_dev = q_yk_dev + _dot_dev(grad, diff) + 0.5 * L * _sum_sq_dev(diff)
                if _to_float_scalar(bound_dev + _SLACK_TOLERANCE - q_new_dev) >= 0:
                    break
                L *= 1.5
                step = 1.0 / L
            coef = coef_new

            if not _is_quadratic:
                # Finiteness check. Restart from the best iterate, or from
                # zero when none has been recorded yet -- otherwise a NaN/inf
                # iterate would be carried through every remaining iteration.
                if not np.isfinite(float(_norm2_dev(coef))):
                    coef = _restart_point(_coef_best_fista, n_features, backend, X_proc)
                    y_k = _copy_arr(coef)
                    t_k = 1.0
                    L = L * 2.0
                    continue

                if iteration > 0:
                    # Divergence detection on the objective at the accepted
                    # iterate (q_new_dev, already computed by backtracking).
                    # The coefficient-norm cap applies after a short burn-in;
                    # objective and norm then come back in one batched sync.
                    _need_norm_check = iteration > 10
                    if _need_norm_check:
                        _obj_val_f, _coef_norm_f = _sync_scalars(
                            q_new_dev, _norm2_dev(coef), backend=backend
                        )
                    else:
                        _obj_val_f = float(_to_numpy(q_new_dev))
                        _coef_norm_f = 0.0
                    _obj_val_f += _tracking_penalty_value(penalty, coef)
                    if not np.isfinite(_obj_val_f):
                        _diverged = True
                    elif _obj_best_fista > 1e-8:
                        # Relative blow-up: above 10x the best objective.
                        _diverged = _obj_val_f > _obj_best_fista * 10.0 + 1e-8
                    else:
                        # Near-zero or negative best: absolute margin instead.
                        _diverged = _obj_val_f > _obj_best_fista + max(
                            abs(_obj_best_fista) * 10.0, 1.0
                        )
                    if _need_norm_check and _coef_norm_f > _DIVERGE_COEF_NORM_CAP:
                        _diverged = True
                    if _diverged:
                        coef = _restart_point(_coef_best_fista, n_features, backend, X_proc)
                        y_k = _copy_arr(coef)
                        t_k = 1.0
                        L = L * 2.0
                        continue
                    if _obj_val_f < _obj_best_fista:
                        _obj_best_fista = _obj_val_f
                        _coef_best_fista = _copy_arr(coef)

                    # Periodic Lipschitz refresh, only when the coefficients
                    # moved noticeably (relative change > 1e-3).
                    if iteration % _lip_interval == 0:
                        _coef_change, _coef_norm = _sync_scalars(
                            _norm2_dev(coef - coef_old), _norm2_dev(coef), backend=backend
                        )
                        if _coef_change / max(_coef_norm, 1e-10) > 1e-3:
                            L_new = _call_with_weight(
                                loss.lipschitz, X_proc, coef, y=y_proc,
                                sample_weight=sample_weight,
                            )
                            if _lip_safety > 1.0:
                                L_new = L_new * _lip_safety
                            if _smooth_lip > 0:
                                L_new = L_new + _smooth_lip
                            L = _relax_lipschitz(L, L_new)

        # Momentum (extrapolation) update -- all backends.
        if _skip_momentum:
            # No momentum (e.g. inverse_gaussian): just copy coef.
            y_k = _copy_arr(coef)
        elif _momentum_beta_cap is not None:
            y_k, t_k = _nesterov_update(coef, coef_old, t_k, beta_cap=_momentum_beta_cap)
        else:
            y_k, t_k = _nesterov_update(coef, coef_old, t_k)

        # Convergence check: every iteration on CPU; on GPU backends for the
        # first 20 iterations and then every _conv_interval iterations.
        if _is_gpu:
            if iteration < 20 or iteration % _conv_interval == 0:
                if _to_float_scalar(_abs_sum_dev(coef - coef_old)) < tol:
                    converged = True
                    break
        elif float(_abs_sum_dev(coef - coef_old)) < tol:
            converged = True
            break

    if _use_gpu_loop and _coef_best_fista is not None:
        # The async path samples the objective only periodically, so the best
        # *sampled* iterate can be many iterations older than the final one.
        # Keep the final iterate when it is at least as good.
        _final_obj = float(_to_numpy(loss.value(X_proc, y_proc, coef)))
        if np.isfinite(_final_obj):
            _final_obj += _tracking_penalty_value(penalty, coef)
            if _final_obj <= _obj_best_fista:
                _coef_best_fista = None

    # Return the best tracked iterate if available.
    if _coef_best_fista is not None:
        coef = _copy_arr(_coef_best_fista)

    n_iter = iteration + 1
    if not converged:
        warnings.warn(
            f"fista_solver did not converge within {max_iter} iterations "
            f"(loss={getattr(loss, 'name', '?')}, penalty={getattr(penalty, 'name', '?')}). "
            f"Consider increasing max_iter or using a different solver (newton, lbfgs, irls).",
            ConvergenceWarning,
            stacklevel=2,
        )
    return coef, n_iter