statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,541 @@
1
+ """Fused LLA+FISTA solver for SCAD/MCP over a continuation path.
2
+
3
+ Runs the entire continuation -> LLA -> FISTA loop in one tight function,
4
+ eliminating per-call overhead (backend detect, preprocess, Lipschitz
5
+ recompute, array allocation) that accumulates over 300+ fista_solver calls.
6
+ """
7
+
8
+ __all__ = ["fista_lla_path"]
9
+
10
+ import copy
11
+ import numpy as np
12
+
13
+ from statgpu.backends import _resolve_backend, _to_numpy
14
+ from statgpu.backends._utils import _to_float_scalar, xp_ones
15
+ from statgpu.backends._array_ops import (
16
+ _abs_sum_dev,
17
+ _clip_grad_on_device,
18
+ _copy_arr,
19
+ _norm2_dev,
20
+ _sync_scalars,
21
+ _zeros,
22
+ )
23
+ from statgpu.penalties._categories import NONSMOOTH as _NONSMOOTH_ALL
24
+ from statgpu.penalties._adaptive_l1 import AdaptiveL1Penalty
25
+ from ._constants import (
26
+ _DIVERGE_COEF_NORM_CAP,
27
+ _GRAD_CLIP_COEF_FACTOR,
28
+ _GRAD_CLIP_ABS_FLOOR,
29
+ _GRAD_CLIP_MAX,
30
+ )
31
+ from ._utils import (
32
+ _nesterov_momentum,
33
+ _validate_sample_weight,
34
+ )
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Fused proximal kernels for squared_error + AdaptiveL1 (SCAD/MCP via LLA)
38
+ # ---------------------------------------------------------------------------
39
+ # Pre-computes XtX, Xty to avoid redundant matmul; fuses element-wise ops;
40
+ # defers GPU->CPU syncs for convergence.
41
+
42
# Module-level caches for the fused proximal kernels. Both start as None and
# are filled in on first use by _get_sqerr_proximal_torch() and
# _get_sqerr_proximal_cupy() respectively, so each kernel is built at most
# once per process.
_SQERR_PROXIMAL_TORCH = None
_SQERR_PROXIMAL_CUPY = None
44
+
45
+
46
def _get_sqerr_proximal_torch():
    """Return the cached fused proximal + momentum update for torch tensors.

    The returned callable has the signature
    ``(y_current, grad, step, thresh, coef_old, beta) -> (coef_new, y_k)``
    and computes::

        w        = y_current - step * grad
        coef_new = sign(w) * max(|w| - thresh, 0)      # soft-threshold
        y_k      = coef_new + beta * (coef_new - coef_old)

    On the first call the kernel is built and stored in the module-level
    ``_SQERR_PROXIMAL_TORCH``; later calls return the same object.

    When CUDA is available with compute capability >= 7 and this torch build
    exposes ``torch.compile``, the kernel is wrapped with
    ``torch.compile(mode='reduce-overhead', backend='inductor')``. In every
    other case (CPU only, older GPU, torch without ``torch.compile``, or a
    failure while wrapping) the plain eager function is used.
    """
    global _SQERR_PROXIMAL_TORCH
    if _SQERR_PROXIMAL_TORCH is not None:
        return _SQERR_PROXIMAL_TORCH

    import torch

    # Single definition of the update, shared by the compiled and the eager
    # variants so the two can never drift apart.
    def _fused_update(y_current, grad, step, thresh, coef_old, beta):
        w = y_current - step * grad
        abs_w = w.abs()
        sign_w = w.sign()
        coef_new = sign_w * (abs_w - thresh).clamp(min=0.0)
        y_k = coef_new + beta * (coef_new - coef_old)
        return coef_new, y_k

    kernel = _fused_update

    # torch.compile requires CUDA capability >= 7.0 (Triton); older GPUs
    # (P100 = 6.0) and CPU-only runs keep the eager function.
    _cap = torch.cuda.get_device_capability()[0] if torch.cuda.is_available() else 0
    # torch < 2.0 has no torch.compile at all; look it up defensively so that
    # case falls back to eager instead of raising AttributeError.
    _compile = getattr(torch, "compile", None)
    if _cap >= 7 and _compile is not None:
        try:
            kernel = _compile(_fused_update, mode='reduce-overhead', backend='inductor')
        except (RuntimeError, TypeError):
            kernel = _fused_update

    _SQERR_PROXIMAL_TORCH = kernel
    return _SQERR_PROXIMAL_TORCH
76
+
77
+
78
def _get_sqerr_proximal_cupy():
    """Return the cached CuPy elementwise kernel for the fused update.

    The kernel takes ``(y_current, grad, step, thresh, coef_old, beta)`` and
    produces ``(coef_new, y_k)``: a gradient step, a soft-threshold by
    ``thresh`` and the momentum extrapolation, all in one elementwise pass.
    It is built on the first call and stored in the module-level
    ``_SQERR_PROXIMAL_CUPY``.
    """
    global _SQERR_PROXIMAL_CUPY
    if _SQERR_PROXIMAL_CUPY is not None:
        return _SQERR_PROXIMAL_CUPY

    import cupy as cp

    in_params = 'T y_current, T grad, T step, T thresh, T coef_old, T beta'
    out_params = 'T coef_new, T y_k'
    operation = '''
        T w = y_current - step * grad;
        T abs_w = abs(w);
        T sign_w = (w > 0) ? 1 : ((w < 0) ? -1 : 0);
        coef_new = (abs_w > thresh) ? sign_w * (abs_w - thresh) : 0;
        y_k = coef_new + beta * (coef_new - coef_old);
        '''
    _SQERR_PROXIMAL_CUPY = cp.ElementwiseKernel(
        in_params,
        out_params,
        operation,
        'sqerr_proximal_fused',
    )
    return _SQERR_PROXIMAL_CUPY
95
+
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # Main solver
99
+ # ---------------------------------------------------------------------------
100
+
101
+
102
def fista_lla_path(
    loss,
    scad_penalty,
    X,
    y,
    alpha_path,
    max_lla_per_step=6,
    lla_tol=1e-6,
    max_iter=1000,
    tol=1e-4,
    fit_intercept=True,
    sample_weight=None,
    lla_penalty_factory=None,
    init_coef=None,
    init_intercept=None,
    return_path=False,
):
    """Fused LLA+FISTA solver for SCAD/MCP over a continuation path.

    Runs the entire continuation -> LLA -> FISTA loop in one function so that
    backend detection, preprocessing, the Lipschitz estimate and (for
    squared error) the Gram matrix are computed once instead of once per
    inner solve.

    Parameters
    ----------
    loss : GLMLoss
        Loss object. ``preprocess``, ``lipschitz`` and
        ``fused_value_and_gradient`` are called on it, and several optional
        private flags (``_is_quadratic``, ``_skip_momentum``,
        ``_momentum_beta_cap``, ``_conservative_momentum_with_nonsmooth``,
        ``_lipschitz_safety``, ``_lipschitz_uses_y``) are read via ``getattr``.
    scad_penalty : SCADPenalty or MCPPenalty
        Penalty providing ``lla_weights``. A shallow copy is taken for every
        continuation step and ``alpha`` is set on that copy, so the object
        passed in is not mutated.
    X, y : array
        Design matrix and response. With ``fit_intercept=True`` a quadratic
        loss has X and y mean-centered here; any other loss gets a column of
        ones appended to X so the intercept is the last coefficient.
    alpha_path : array of alpha values (descending, geomspace)
    max_lla_per_step : int
        Maximum number of LLA re-weighting rounds per continuation alpha.
    lla_tol : float
        LLA stops once the L1 change of the coefficients over one round is
        below this value.
    max_iter : int or list[int]
        FISTA iteration limit. If a list or tuple, one value per
        continuation step.
    tol : float
        FISTA stops once the L1 change of the coefficients over one
        iteration is below this value.
    fit_intercept : bool
    sample_weight : array or None
        When given, the pre-computed (unweighted) Gram-matrix shortcuts are
        disabled and the weights are forwarded to
        ``loss.fused_value_and_gradient``.
    lla_penalty_factory : callable or None
        If given, it is called with the LLA weights (converted to NumPy) at
        every LLA round and its return value is used as the inner penalty.
        It is not consulted on the fused squared-error GPU path.
    init_coef : array or None
        Warm-start coefficients (without intercept). If provided, they are
        injected only at the final target-alpha continuation step.
    init_intercept : float or None
        Warm-start intercept value. Only used when the intercept is carried
        as an appended coefficient (non-quadratic loss).
    return_path : bool, default=False
        When True, also return coefficients/intercepts after each continuation
        alpha. The default keeps the historical 3-tuple return value.

    Returns
    -------
    coef : array (p,)
        Final coefficients as a float64 NumPy array, intercept excluded.
    intercept : float
    total_iter : int
        Number of FISTA iterations executed over the whole path, including
        the iteration on which an inner solve stopped early.
    path : dict
        Only returned when ``return_path=True``. Keys ``"alpha"``,
        ``"coef"`` (one row per continuation alpha), ``"intercept"`` and
        ``"n_iter"`` (cumulative iteration count), each a NumPy array.
    """
    backend = _resolve_backend("auto", X)
    if backend == "torch":
        import torch
        xp = torch
        # Non-floating inputs are promoted to float64; X and y end up with a
        # common floating dtype on X's device.
        x_dtype = X.dtype if getattr(X, "is_floating_point", lambda: False)() else torch.float64
        y_dtype = y.dtype if getattr(y, "is_floating_point", lambda: False)() else torch.float64
        common_dtype = torch.promote_types(x_dtype, y_dtype)
        X = X.to(dtype=common_dtype)
        y = torch.as_tensor(y, device=X.device, dtype=common_dtype)
    elif backend == "cupy":
        import cupy as xp
    else:
        xp = np
    # Only the shape of the preprocessed design matrix is used below.
    X_proc, _ = loss.preprocess(X, y)
    _is_quadratic = getattr(loss, '_is_quadratic', False)
    _no_momentum = getattr(loss, '_skip_momentum', False)
    _non_smooth_pen_lla = getattr(scad_penalty, 'name', '') in _NONSMOOTH_ALL
    # NOTE(review): only the *presence* of loss._momentum_beta_cap is used;
    # the cap applied in the inner loop is the fixed value 0.5 -- confirm
    # that ignoring the declared value is intended.
    _momentum_beta_cap = getattr(loss, '_momentum_beta_cap', None)
    _conservative_momentum_lla = (
        _momentum_beta_cap is not None
        or (getattr(loss, '_conservative_momentum_with_nonsmooth', False)
            and _non_smooth_pen_lla)
    )

    n_samples, n_features = X_proc.shape
    _validate_sample_weight(sample_weight, n_samples)

    # --- Intercept handling ---
    # For squared_error (identity link): centering X, y is exact.
    # For GLM losses (log/logit link): centering is WRONG -- it changes
    # the objective. Instead, augment X with a ones column so the
    # intercept is part of the coefficient vector.
    _augment_intercept = fit_intercept and not _is_quadratic
    if _augment_intercept:
        # Augment X with a column of ones
        ones_col = xp_ones((X.shape[0], 1), dtype=X.dtype, xp=xp, ref_arr=X)
        X_c = xp.concatenate([X, ones_col], axis=1)
        y_c = y
        n_aug = n_features + 1
    elif fit_intercept:
        # squared_error: centering is exact for identity link
        X_mean = xp.mean(X, axis=0)
        y_mean = xp.mean(y)
        X_c = X - X_mean
        y_c = y - y_mean
        n_aug = n_features
    else:
        X_c = X
        y_c = y
        n_aug = n_features

    # Precompute Lipschitz using loss-specific method.
    # Pass zero coef (global bound) -- not all losses handle coef=None.
    _zero_coef_lla = _zeros(n_aug, backend, ref_tensor=X_c)
    L_base = loss.lipschitz(X_c, _zero_coef_lla, y=y_c)
    # Precompute XtX only for squared_error fast path (skip for GLM losses)
    XtX = X_c.T @ X_c if _is_quadratic else None
    if L_base <= 0:
        L_base = 1.0

    # Apply loss-specific Lipschitz safety factor (e.g. NB=2x, gamma=3x)
    _lipschitz_safety = getattr(loss, '_lipschitz_safety', 1.0)
    if _lipschitz_safety > 1.0:
        L_base = L_base * _lipschitz_safety

    # Y-scaling for exp-link families (Poisson, Gamma, etc.).
    # At coef=0, mu~1, but near the optimum mu~y. The Hessian scales
    # with mu, so L_base underestimates by up to max(y).
    # Cap at 10x -- periodic Lipschitz recomputation corrects any remaining
    # underestimate during the FISTA inner loop.
    _skip_y_scaling = getattr(loss, '_lipschitz_uses_y', False)
    _y_lipschitz_scale = 1.0
    if not _is_quadratic and not _skip_y_scaling:
        _y_arr = _to_numpy(y_c)
        _y_abs = np.abs(_y_arr)
        _y_mean = float(np.mean(_y_abs))
        _y_max = float(np.max(_y_abs))
        _y_lipschitz_scale = min(10.0, max(1.0, np.sqrt(_y_mean * _y_max)))
        if _y_lipschitz_scale > 1.0:
            L_base = L_base * _y_lipschitz_scale

    def _zeros_coef():
        # Zero coefficient vector (including the intercept slot, if any).
        return _zeros(n_aug, backend, ref_tensor=X_c)

    def _warm_start_coef():
        # Convert init_coef to the active backend; when the intercept is an
        # appended coefficient and init_coef has only n_features entries,
        # append init_intercept (0.0 if None) as the last element.
        if init_coef is None:
            return None
        if backend == "torch":
            import torch
            _init = torch.as_tensor(init_coef, device=X_c.device, dtype=X_c.dtype)
            if _augment_intercept and _init.shape[0] == n_features:
                return torch.cat([
                    _init,
                    torch.tensor(
                        [0.0 if init_intercept is None else init_intercept],
                        device=X_c.device,
                        dtype=X_c.dtype,
                    ),
                ])
            return _init.clone()
        if backend == "cupy":
            import cupy as cp
            _init = cp.asarray(init_coef, dtype=X_c.dtype)
            if _augment_intercept and _init.shape[0] == n_features:
                return cp.concatenate([
                    _init,
                    cp.array([0.0 if init_intercept is None else init_intercept], dtype=X_c.dtype),
                ])
            return _init.copy()
        _init = np.asarray(init_coef, dtype=np.float64)
        if _augment_intercept and _init.shape[0] == n_features:
            return np.concatenate([
                _init,
                [0.0 if init_intercept is None else float(init_intercept)],
            ])
        return _init.copy()

    # Keep the continuation path deterministic from zero. CV warm-starts are
    # injected only at the target-alpha step, otherwise SCAD/MCP LLA weights can
    # follow a different local trajectory for NB/Tweedie-like losses.
    coef = _zeros_coef()
    warm_coef = _warm_start_coef()

    total_iter = 0
    inner_pen = AdaptiveL1Penalty(alpha=1.0)
    path_records = [] if return_path else None

    def _split_current_coef(current_coef):
        # Return (feature coefficients, intercept) as float64 NumPy / float.
        coef_all = np.asarray(_to_numpy(current_coef), dtype=np.float64).ravel()
        if _augment_intercept:
            return coef_all[:n_features].copy(), float(coef_all[n_features])
        if fit_intercept:
            # Centered fit: recover the intercept from the column means.
            X_mean_np = np.asarray(_to_numpy(X_mean), dtype=np.float64).ravel()
            y_mean_np = float(_to_numpy(y_mean))
            return coef_all.copy(), float(y_mean_np - X_mean_np @ coef_all)
        return coef_all.copy(), 0.0

    def _record_path_alpha(alpha_value):
        # Snapshot the current solution; reads coef/total_iter from the
        # enclosing scope at call time.
        if path_records is None:
            return
        coef_rec, intercept_rec = _split_current_coef(coef)
        path_records.append({
            "alpha": float(alpha_value),
            "coef": coef_rec,
            "intercept": float(intercept_rec),
            "n_iter": int(total_iter),
        })

    # For squared_error + GPU: fully inlined fused loop.
    # Uses torch.compile for torch, ElementwiseKernel for cupy.
    # Must gate on sample_weight is None because the fused path uses
    # unweighted Gram matrix (XtX, Xty) which is incorrect for weighted data.
    if _is_quadratic and backend in ("torch", "cupy") and sample_weight is None:
        Xty = X_c.T @ y_c

        # Get fused proximal kernel
        if backend == "torch":
            _fused = _get_sqerr_proximal_torch()
        else:
            _fused = _get_sqerr_proximal_cupy()

        step = 1.0 / L_base
        _conv_interval = 20  # check convergence every N iters (reduced GPU sync)

        for _cont_i, cont_alpha in enumerate(alpha_path):
            # Create a copy with the continuation alpha to avoid mutating
            # the shared penalty object (thread-safety for future parallel CV).
            _pen_step = copy.copy(scad_penalty)
            _pen_step.alpha = float(cont_alpha)
            _mi = max_iter[_cont_i] if isinstance(max_iter, (list, tuple)) else max_iter
            if warm_coef is not None and _cont_i == len(alpha_path) - 1:
                coef = _copy_arr(warm_coef)
            for _lla_i in range(max_lla_per_step):
                # lla_weights() is backend-aware -- stays on device
                lla_w = _pen_step.lla_weights(coef)
                thresh = lla_w * step  # stays on device

                # Save coef for LLA convergence check (on device)
                coef_before_lla = _copy_arr(coef)

                # Reset momentum for new LLA step
                t_k = 1.0
                y_k = _copy_arr(coef)

                # FISTA inner solve (inlined, fused proximal+momentum)
                iteration = -1  # keeps the count below correct when _mi == 0
                for iteration in range(_mi):
                    coef_old = _copy_arr(coef)

                    # Gradient: grad = (XtX @ y_k - Xty) / n
                    grad = (XtX @ y_k - Xty) / n_samples

                    # Clip gradients
                    if iteration % 10 == 0:
                        grad = _clip_grad_on_device(grad, coef_old, backend)

                    # Compute momentum beta BEFORE proximal so fused kernel does both
                    if _no_momentum:
                        beta_mom = 0.0
                    else:
                        beta_mom, t_k = _nesterov_momentum(t_k)

                    # Fused proximal + momentum in one kernel call. The gradient
                    # is evaluated at y_k, so y_k is the proximal center.
                    coef, y_k = _fused(y_k, grad, step, thresh, coef_old, beta_mom)

                    # Convergence check (device-side, minimal sync)
                    if iteration < 20 or iteration % _conv_interval == 0:
                        coef_diff_dev = _abs_sum_dev(coef - coef_old)
                        _cdf = _to_float_scalar(coef_diff_dev)
                        converged = _cdf < tol
                        diverged = (not np.isfinite(_cdf))
                        if converged:
                            break
                        if diverged:
                            # Roll back to the last finite iterate.
                            coef = _copy_arr(coef_old)
                            break

                total_iter += iteration + 1

                # LLA convergence check (device-side, minimal sync)
                delta_dev = _abs_sum_dev(coef - coef_before_lla)
                if _to_float_scalar(delta_dev) < lla_tol:
                    break
            _record_path_alpha(cont_alpha)
    else:
        # Pre-compute XtX and Xty for squared_error (avoids redundant matmuls).
        # Must gate on sample_weight is None because XtX/Xty are unweighted.
        _use_xtx = _is_quadratic and backend == "numpy" and sample_weight is None
        if _use_xtx:
            Xty = X_c.T @ y_c

        # Loop invariants: fused proximal+momentum kernel for GPU paths and
        # the device-side tolerance used by the convergence check.
        if backend == "torch":
            _fused_update = _get_sqerr_proximal_torch()
        elif backend == "cupy":
            _fused_update = _get_sqerr_proximal_cupy()
        else:
            _fused_update = None
        if backend != "numpy":
            _tol_dev = xp.asarray(tol)

        for _cont_i, cont_alpha in enumerate(alpha_path):
            # Create a copy with the continuation alpha to avoid mutating
            # the shared penalty object (thread-safety for future parallel CV).
            _pen_step = copy.copy(scad_penalty)
            _pen_step.alpha = float(cont_alpha)
            _mi = max_iter[_cont_i] if isinstance(max_iter, (list, tuple)) else max_iter
            if warm_coef is not None and _cont_i == len(alpha_path) - 1:
                coef = _copy_arr(warm_coef)

            for _lla_i in range(max_lla_per_step):
                # lla_weights() is backend-aware -- stays on device
                if _augment_intercept:
                    lla_w_feat = _pen_step.lla_weights(coef[:n_features])
                    # Append 0.0 for intercept on device (intercept is unpenalized)
                    _zero_append = _zeros(1, backend, ref_tensor=coef)
                    lla_w = xp.concatenate([lla_w_feat, _zero_append])
                else:
                    lla_w = _pen_step.lla_weights(coef)
                if lla_penalty_factory is not None:
                    # lla_penalty_factory expects numpy; convert only if needed
                    lla_w_np = _to_numpy(lla_w) if type(lla_w).__module__ != "numpy" else lla_w
                    inner_pen = lla_penalty_factory(lla_w_np)
                else:
                    inner_pen._weights = lla_w

                # Save coef for LLA convergence check (on device)
                coef_before_lla = _copy_arr(coef)

                # --- FISTA inner solve (fixed-step, no backtracking) ---
                y_k = _copy_arr(coef)
                t_k = 1.0
                L = L_base
                step = 1.0 / L

                # --- Async inner loop: skip backtracking, use fixed step ---
                # For LLA, the Lipschitz constant L is pre-computed and stable.
                # Backtracking is unnecessary -- use fixed step 1/L.
                # This eliminates per-iteration GPU->CPU syncs.
                iteration = -1  # keeps the count below correct when _mi == 0
                for iteration in range(_mi):
                    coef_old = _copy_arr(coef)

                    if _use_xtx:
                        grad = (XtX @ y_k - Xty) / n_samples
                    else:
                        if sample_weight is not None:
                            _, grad = loss.fused_value_and_gradient(
                                X_c, y_c, y_k, sample_weight=sample_weight,
                            )
                        else:
                            _, grad = loss.fused_value_and_gradient(X_c, y_c, y_k)

                    # Clip gradients (every iteration on numpy, every 10th on GPU)
                    if backend == "numpy" or iteration % 10 == 0:
                        _gn_dev = _norm2_dev(grad)
                        _gsum = _abs_sum_dev(coef_old) * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR
                        if backend == "torch":
                            _gmax_dev = xp.clamp(_gsum, min=_GRAD_CLIP_MAX)
                        else:
                            _gmax_dev = xp.maximum(_gsum, _GRAD_CLIP_MAX)
                        _gn_f, _gmax_f = _sync_scalars(_gn_dev, _gmax_dev, backend=backend)
                        if _gn_f > _gmax_f:
                            grad = grad * (_gmax_dev / _gn_dev)

                    # Compute momentum beta before fused update
                    if _no_momentum:
                        beta_mom = 0.0
                    elif _conservative_momentum_lla:
                        beta_mom, t_k = _nesterov_momentum(t_k, beta_cap=0.5)
                    else:
                        beta_mom, t_k = _nesterov_momentum(t_k)

                    # Fused proximal + momentum: single kernel launch on GPU
                    # Combines: w_tilde = y_k - step*grad
                    #           coef = proximal(w_tilde, step)  [weighted soft-threshold]
                    #           y_k = coef + beta * (coef - coef_old)
                    if _fused_update is not None and backend != "numpy":
                        # Ensure thresh is on the correct device
                        _w = inner_pen._weights
                        if isinstance(_w, np.ndarray):
                            _w = xp.asarray(_w, dtype=coef.dtype)
                        thresh = _w * inner_pen.alpha * step
                        coef, y_k = _fused_update(y_k, grad, step, thresh, coef_old, beta_mom)
                    else:
                        w_tilde = y_k - step * grad
                        coef = inner_pen.proximal(w_tilde, step, backend=backend)
                        y_k = coef + beta_mom * (coef - coef_old)

                    # Convergence (device-side comparison, only D2H 1 bool)
                    if backend == "numpy" or iteration < 20 or iteration % 5 == 0:
                        _conv_dev = _abs_sum_dev(coef - coef_old)
                        if backend != "numpy":
                            if bool(_to_numpy(_conv_dev < _tol_dev)):
                                break
                        else:
                            if float(_to_numpy(_conv_dev)) < tol:
                                break

                    # Periodic Lipschitz recomputation -- corrects stale L
                    # as coef moves away from zero.
                    if not _is_quadratic and iteration > 0 and iteration % 20 == 0:
                        L_new = loss.lipschitz(X_c, coef, y=y_c)
                        if _y_lipschitz_scale > 1.0:
                            L_new = L_new * _y_lipschitz_scale
                        if L_new > L * 1.5 or L_new < L / 1.5:
                            L = max(L_new, L_base * 0.1)
                            step = 1.0 / L
                # --- end FISTA ---

                # Count every executed iteration, including the one that hit
                # `break`, so both solver paths report iterations the same way.
                total_iter += iteration + 1

                # LLA convergence (on device, single sync for scalar)
                delta = float(_to_numpy(_abs_sum_dev(coef - coef_before_lla)))
                if delta < lla_tol:
                    break
            _record_path_alpha(cont_alpha)

    # Extract coef and intercept
    coef_np, intercept = _split_current_coef(coef)

    if return_path:
        if path_records:
            path = {
                "alpha": np.asarray([r["alpha"] for r in path_records], dtype=np.float64),
                "coef": np.vstack([r["coef"] for r in path_records]).astype(np.float64, copy=False),
                "intercept": np.asarray([r["intercept"] for r in path_records], dtype=np.float64),
                "n_iter": np.asarray([r["n_iter"] for r in path_records], dtype=np.int64),
            }
        else:
            # Empty alpha_path: return correctly shaped empty arrays.
            path = {
                "alpha": np.empty(0, dtype=np.float64),
                "coef": np.empty((0, n_features), dtype=np.float64),
                "intercept": np.empty(0, dtype=np.float64),
                "n_iter": np.empty(0, dtype=np.int64),
            }
        return coef_np, intercept, total_iter, path
    return coef_np, intercept, total_iter