statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2699 @@
1
+ """
2
+ Unified cross-validated penalized GLM estimator.
3
+
4
+ Supports all GLM loss functions (squared_error, logistic, poisson, gamma,
5
+ inverse_gaussian, negative_binomial, tweedie) with all penalty types
6
+ (l1, l2, elasticnet, scad, mcp, adaptive_l1, group_lasso).
7
+
8
+ Optimizations:
9
+ - Warm-start across alpha values (descending order)
10
+ - Batch eigendecomposition for squared_error + l2 (CPU/CuPy/Torch)
11
+ - Precomputed loss function and cached validation data per fold
12
+ - Minimal D2H transfers
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ __all__ = ["PenalizedGLM_CV"]
18
+
19
+ import logging
20
+ import warnings
21
+
22
+ logger = logging.getLogger(__name__)
23
+ from typing import Optional, Union
24
+
25
+ import numpy as np
26
+
27
+ from statgpu._config import Device
28
+ from statgpu.backends import _to_numpy
29
+ from statgpu.backends._array_ops import _copy_arr, _zeros, _xp_zeros, _soft_threshold
30
+ from statgpu.backends._utils import _to_float_scalar
31
+ from statgpu.cross_validation._base import CVEstimatorBase, kfold_indices
32
+ from statgpu.solvers._utils import _nesterov_momentum
33
+
34
+
35
# ---------------------------------------------------------------------------
# Numerical constants for GLM loss computation (shared across CV paths)
# ---------------------------------------------------------------------------

# Eta (linear predictor) clipping bounds: applied to eta before exp() in the
# log-link residual/loss helpers below, to prevent overflow.
_ETA_CLIP_STANDARD = 30.0  # Poisson, Gamma, NB, InvGauss
_ETA_CLIP_TWEEDIE = 50.0  # Tweedie (wider range for mu^p stability)
# NOTE(review): not referenced by the logistic helpers visible in this chunk
# (_res_logistic/_val_logistic do not clip); confirm where it is used.
_ETA_CLIP_LOGISTIC = 500.0  # Logistic (sigmoid saturates, safe range)

# Mu clipping bounds (prevent division by zero / log(0)).
# _MU_LO floors mu for Poisson, Gamma and NB, and floors mu**2 (and mu) for
# inverse Gaussian in the helpers below.
_MU_LO = 1e-10  # Standard lower bound for mu
_MU_LO_TWEEDIE = 1e-3  # Tweedie lower bound
_MU_HI_TWEEDIE = 1e4  # Tweedie upper bound
# NOTE(review): the three bounds below are not referenced in the visible part
# of this module; the inverse Gaussian and NB helpers here use _MU_LO instead.
_MU_LO_INVGAUSS = 5e-2  # Inverse Gaussian lower bound
_MU_HI_INVGAUSS = 1e3  # Inverse Gaussian upper bound
_MU_LO_NB = 1e-300  # Negative binomial lower bound

# Default loss parameters (must match loss object defaults). Used as the
# getattr fallback in _evaluate_loss_numpy and as keyword defaults of the
# NB / Tweedie residual and loss helpers.
_NB_ALPHA_DEFAULT = 1.0  # NegativeBinomialLoss default alpha
_TWEEDIE_POWER_DEFAULT = 1.5  # TweedieLoss default power

# ---------------------------------------------------------------------------
# CV solver tuning constants
# ---------------------------------------------------------------------------

# Eigenvalue floor applied to X'X eigenvalues in the ridge eigendecomposition
# helpers (prevents division by zero).
_EIGVAL_FLOOR = 1e-15

# FISTA iteration caps for CV (lower than full fit to keep CV fast)
_FISTA_MAX_ITER_CV = 400  # Default max FISTA iterations per alpha in CV (also the GPU L1 logistic cap)
# Also used as the GPU elastic-net logistic cap in _logistic_sparse_effective_max_iter.
_FISTA_MAX_ITER_CV_SMALL = 600  # For small problems (n*p < _SMALL_PROBLEM_THRESHOLD)

# Convergence check intervals (sync cost vs responsiveness tradeoff)
_CONV_INTERVAL_CV_DEFAULT = 200  # Default convergence check interval
_CONV_INTERVAL_CV_TIGHT = 30  # Tighter interval for first few alphas
_CONV_INTERVAL_CV_FOLD = 50  # Per-fold convergence interval
_CONV_INTERVAL_CV_PATH = 25  # Path-based convergence interval
_CONV_INTERVAL_CV_NUMPY = 10  # Numpy path (no sync cost)

# Problem size thresholds
_SMALL_PROBLEM_THRESHOLD = 200_000  # n*p below this = "small problem"
_GPU_BREAK_EVEN_THRESHOLD = 100_000_000  # CV work below this = CPU faster

# IRLS deviance tolerance constants
_IRLS_DEV_TOL_REL = 1e-10  # Relative deviance tolerance
_IRLS_DEV_TOL_ABS = 1e-6  # Absolute deviance tolerance floor
82
+
83
class ApproximateCVWarning(UserWarning):
    """Category for warnings raised when approximate two-stage CV screening is used."""
85
+
86
+
87
def _is_uniform_weight(sample_weight) -> bool:
    """Tell whether ``sample_weight`` can be treated as "no weighting".

    True for ``None``, for an empty weight vector, and for a vector whose
    entries are all ``np.allclose`` to its first entry.
    """
    if sample_weight is None:
        return True
    weights = np.asarray(_to_numpy(sample_weight), dtype=np.float64).ravel()
    if weights.size == 0:
        return True
    return np.allclose(weights, weights[0])
93
+
94
+
95
def _device_to_name(device):
    """Normalise a device spec to a name.

    A ``Device`` member maps to its ``.value``; anything else maps to its
    lowercased ``str()``.
    """
    return device.value if isinstance(device, Device) else str(device).lower()
99
+
100
+
101
def _slice_rows(arr, idx):
    """Select rows of ``arr`` at ``idx`` using the array's own backend.

    - CuPy arrays are indexed with ``cp.asarray(idx)``.
    - Torch tensors are indexed with an index tensor on ``arr.device``.
      Boolean masks stay boolean; casting a mask to ``torch.long`` would
      silently turn it into row indices 0/1 and pick the wrong rows.
    - Anything else is indexed directly. If that raises ``TypeError`` (e.g.
      a Python list indexed by a list), it is converted with ``np.asarray``
      first.
    """
    mod = type(arr).__module__
    if mod.startswith("cupy"):
        import cupy as cp
        return arr[cp.asarray(idx)]
    if mod.startswith("torch"):
        import torch
        if isinstance(idx, torch.Tensor):
            is_mask = idx.dtype == torch.bool
        else:
            # Read .dtype when present (numpy/cupy arrays) so array indices
            # are not copied to host just to inspect their type.
            idx_dtype = getattr(idx, "dtype", None)
            if idx_dtype is None:
                idx_dtype = np.asarray(idx).dtype
            is_mask = idx_dtype == np.bool_
        index_dtype = torch.bool if is_mask else torch.long
        return arr[torch.as_tensor(idx, dtype=index_dtype, device=arr.device)]
    try:
        return arr[idx]
    except TypeError:
        return np.asarray(arr)[idx]
114
+
115
+
116
def _nanargmin_prefer_larger_alpha(scores, alpha_grid, rel_tol=1e-10, abs_tol=1e-12):
    """Select min score with deterministic tie-break toward stronger regularization.

    Parameters
    ----------
    scores : array-like
        CV scores, one per alpha. NaN/Inf entries are never selected.
    alpha_grid : array-like
        Alpha values aligned with ``scores``.
    rel_tol, abs_tol : float
        A finite score counts as tied with the best one if it lies within
        ``max(abs_tol, |best| * rel_tol)`` of it.

    Returns
    -------
    int
        Index of the largest alpha among the (near-)best finite scores. If no
        score is finite, a warning is emitted and 0 is returned.
    """
    scores = np.asarray(scores, dtype=np.float64)
    alpha_grid = np.asarray(alpha_grid, dtype=np.float64)
    finite = np.isfinite(scores)
    if not np.any(finite):
        # All scores are NaN/Inf — fall back to first alpha (strongest regularization)
        warnings.warn("All CV scores are NaN/Inf; returning first alpha.", stacklevel=2)
        return 0
    # Minimum over finite entries only. np.nanmin skips NaN but not -inf; a
    # -inf minimum would make ``best + tol`` NaN, leave no candidates, and
    # make np.argmax fail on an empty array.
    best = float(np.min(scores[finite]))
    tol = max(float(abs_tol), abs(best) * float(rel_tol))
    candidates = np.flatnonzero(finite & (scores <= best + tol))
    return int(candidates[np.argmax(alpha_grid[candidates])])
129
+
130
+
131
def _two_stage_candidate_mask(scores, refine_top_k=3):
    """Return alpha candidates to strictly refine after approximate screening.

    The boolean mask (same length as the flattened ``scores``) selects:

    - both endpoint alphas;
    - the ``refine_top_k`` best finite scores plus their immediate neighbours;
    - every finite score within ``max(0.5% of |best|, 1e-6)`` of the best.

    If no score is finite, a warning is emitted and every candidate is selected.
    """
    scores = np.asarray(scores, dtype=np.float64).ravel()
    n_scores = scores.size
    mask = np.zeros(n_scores, dtype=bool)
    if n_scores == 0:
        return mask
    finite = np.isfinite(scores)
    if not np.any(finite):
        warnings.warn("All approximate CV scores are NaN; refining all candidates.", stacklevel=2)
        mask[:] = True
        return mask

    # Endpoint alphas are common optima on flat or monotone CV curves. Always
    # refine them so approximate screening cannot drop boundary solutions.
    mask[0] = True
    mask[-1] = True

    k = min(max(1, int(refine_top_k)), int(np.count_nonzero(finite)))
    ranked = np.argsort(np.where(finite, scores, np.inf))[:k]
    for idx in ranked:
        lo = max(0, int(idx) - 1)
        hi = min(n_scores, int(idx) + 2)
        mask[lo:hi] = True

    # Best over finite scores only: np.nanmin would return -inf when present,
    # turning the near-best window into NaN and silently selecting nothing.
    best = float(np.min(scores[finite]))
    near_tol = max(abs(best) * 0.005, 1e-6)
    mask |= finite & (scores <= best + near_tol)
    return mask
160
+
161
+
162
+ # ---------------------------------------------------------------------------
163
+ # Per-sample loss function for squared_error (unique signature: needs X_design)
164
+ # ---------------------------------------------------------------------------
165
def _ps_squared_error(eta, y, X_design=None, coef_with_intercept=None, **_):
    """Per-sample squared error ``(y - X_design @ coef_with_intercept) ** 2``.

    ``eta`` is accepted for signature parity with the GLM loss helpers but is
    not used; predictions are recomputed from the design matrix.
    """
    residual = y - X_design @ coef_with_intercept
    return residual ** 2


# Maps loss_name -> (per_sample_fn, uses_design).
#   uses_design=True:  fn needs X_design and coef_with_intercept (squared_error)
#   uses_design=False: fn works directly on eta (all GLM losses)
# Filled in further down, once the loss formula registry functions exist.
_LOSS_EVAL_DISPATCH = dict()
174
+
175
+
176
def _weighted_mean(per_sample, sw):
    """Average ``per_sample``, weighting by ``sw`` when it is given.

    Falls back to the plain mean when ``sw`` is None or its total is <= 0.
    """
    if sw is None:
        return float(np.mean(per_sample))
    total_weight = float(np.sum(sw))
    if total_weight <= 0:
        return float(np.mean(per_sample))
    return float(np.dot(sw, per_sample) / total_weight)
184
+
185
+
186
def _numpy_design_with_intercept(X_np, coef_np, intercept, fit_intercept):
    """Return ``(X_design, coef_with_intercept)``, prepending a ones column / intercept entry if requested."""
    if not fit_intercept:
        return X_np, coef_np
    n_rows = X_np.shape[0]
    X_design = np.column_stack([np.ones(n_rows), X_np])
    coef_with_intercept = np.concatenate([[float(intercept)], coef_np])
    return X_design, coef_with_intercept


def _evaluate_loss_numpy(loss_name, loss_fn, X_val_np, y_val_np, coef_np, intercept, fit_intercept, sample_weight=None):
    """Backend-independent validation loss in float64 numpy.

    For losses listed in ``_LOSS_EVAL_DISPATCH``, per-sample losses are
    evaluated at ``eta = X_val @ coef (+ intercept if fit_intercept)`` and
    reduced with ``_weighted_mean``. The result is the weighted mean when
    ``sample_weight`` is given, otherwise the plain mean. NB ``alpha`` and
    Tweedie ``power`` are read from ``loss_fn``, defaulting to the module
    defaults.

    For any other ``loss_name``, ``loss_fn.value(X_design, y, coef)`` is
    called on an intercept-augmented design. Sample weights cannot be applied
    on that path; they are ignored and a ``RuntimeWarning`` is emitted.
    """
    coef_np = np.asarray(coef_np, dtype=np.float64).ravel()
    sw = np.asarray(sample_weight, dtype=np.float64).ravel() if sample_weight is not None else None

    entry = _LOSS_EVAL_DISPATCH.get(loss_name)
    if entry is None:
        # Fallback for unknown loss types: unweighted loss_fn.value(). A
        # weighted mean cannot be derived from an unweighted mean.
        X_design, coef_with_intercept = _numpy_design_with_intercept(X_val_np, coef_np, intercept, fit_intercept)
        if sw is not None:
            warnings.warn(
                f"_evaluate_loss_numpy: loss '{loss_name}' not in dispatch table, "
                f"falling back to unweighted loss_fn.value(). Sample weights ignored.",
                RuntimeWarning,
                stacklevel=2,
            )
        return float(loss_fn.value(X_design, y_val_np, coef_with_intercept))

    # Resolve loss-specific parameters from the loss_fn object.
    loss_params = {}
    if loss_name == "negative_binomial":
        loss_params["alpha"] = float(getattr(loss_fn, "alpha", _NB_ALPHA_DEFAULT))
    elif loss_name == "tweedie":
        loss_params["power"] = float(getattr(loss_fn, "power", _TWEEDIE_POWER_DEFAULT))

    per_sample_fn, uses_design = entry
    eta = X_val_np @ coef_np + (float(intercept) if fit_intercept else 0.0)
    if uses_design:
        X_design, coef_with_intercept = _numpy_design_with_intercept(X_val_np, coef_np, intercept, fit_intercept)
        per_sample = per_sample_fn(
            eta, y_val_np, X_design=X_design, coef_with_intercept=coef_with_intercept, **loss_params
        )
    else:
        per_sample = per_sample_fn(eta, y_val_np, **loss_params)
    return _weighted_mean(per_sample, sw)
238
+
239
+
240
def _ridge_eig_batch(X_train_np, y_train_np, X_val_np, y_val_np, alphas_np):
    """Batch Ridge solve via eigendecomposition on numpy.

    Solves ``(Xc'Xc + n * alpha * I) coef = Xc'yc`` for every alpha using one
    eigendecomposition of the centered Gram matrix. Eigenvalues are floored
    at ``_EIGVAL_FLOOR``. All computation is in float64 numpy.

    ``alphas_np`` may be any array-like of alphas (list, tuple, scalar, or
    ndarray); it is coerced to a 1-D float64 array. A plain Python list used
    to be repeated by ``n * alphas`` instead of being scaled.

    Returns
    -------
    (mse, coefs, intercepts)
        ``mse`` has shape (n_alphas,), ``coefs`` has shape (p, n_alphas), and
        ``intercepts`` has shape (n_alphas,). ``mse`` is the mean squared
        error on the validation data.
    """
    alphas_np = np.atleast_1d(np.asarray(alphas_np, dtype=np.float64))
    n = X_train_np.shape[0]

    X_mean = np.mean(X_train_np, axis=0)
    y_mean = np.mean(y_train_np)
    Xc = X_train_np - X_mean
    yc = y_train_np - y_mean

    eigvals, Q = np.linalg.eigh(Xc.T @ Xc)
    eigvals = np.maximum(eigvals, _EIGVAL_FLOOR)

    QtXty = Q.T @ (Xc.T @ yc)
    # One column of spectral shrinkage factors per alpha.
    inv_diag = 1.0 / (eigvals[:, None] + (n * alphas_np)[None, :])
    coefs = Q @ (inv_diag * QtXty[:, None])
    intercepts = y_mean - X_mean @ coefs

    # Predict: X_val @ coef + intercept (intercept already includes -X_mean @ coef)
    y_pred = X_val_np @ coefs + intercepts[None, :]
    mse = np.mean((y_val_np[:, None] - y_pred) ** 2, axis=0)

    return mse, coefs, intercepts
269
+
270
+
271
def _ridge_eig_single(X_train_np, y_train_np, alpha, sample_weight=None):
    """Single Ridge solve via eigendecomposition. Returns (coef, intercept).

    Unweighted: center X and y by their means, then solve
    ``(Xc'Xc + n * alpha) coef = Xc'yc``.

    Weighted (``sample_weight`` given): center by weighted means, then solve
    ``(Xc' W Xc + sum(w) * alpha) coef = Xc' W yc``.

    Either way, one eigendecomposition of the (weighted) Gram matrix is used,
    with eigenvalues floored at ``_EIGVAL_FLOOR``. The cost is O(p^3).
    """
    n_rows, _ = X_train_np.shape
    if sample_weight is None:
        x_bar = np.mean(X_train_np, axis=0)
        y_bar = np.mean(y_train_np)
        X_cen = X_train_np - x_bar
        y_cen = y_train_np - y_bar
        gram = X_cen.T @ X_cen
        rhs = X_cen.T @ y_cen
        ridge_shift = n_rows * alpha
    else:
        w = np.asarray(sample_weight, dtype=np.float64).ravel()
        total_w = w.sum()
        x_bar = np.average(X_train_np, axis=0, weights=w)
        y_bar = float(np.average(y_train_np, weights=w))
        X_cen = X_train_np - x_bar
        y_cen = y_train_np - y_bar
        # Xc' diag(w) Xc computed as (sqrt(w) * Xc)' (sqrt(w) * Xc).
        X_scaled = X_cen * np.sqrt(w)[:, None]
        gram = X_scaled.T @ X_scaled
        rhs = (X_cen * w[:, None]).T @ y_cen
        ridge_shift = total_w * alpha

    evals, evecs = np.linalg.eigh(gram)
    evals = np.maximum(evals, _EIGVAL_FLOOR)
    projected = evecs.T @ rhs
    shrink = 1.0 / (evals + ridge_shift)
    coef = evecs @ (shrink * projected)
    return coef, float(y_bar - x_bar @ coef)
311
+
312
+
313
def _backend_name_for_cv_device(device):
    """Map a CV device spec to an array backend name.

    "cuda" -> "cupy", "torch" -> "torch", and anything else -> "numpy".
    """
    return {"cuda": "cupy", "torch": "torch"}.get(_device_to_name(device), "numpy")
320
+
321
+
322
+ # Import shared utility from statgpu.cross_validation._base
323
+ from statgpu.cross_validation._base import _torch_cuda_available
324
+
325
+
326
def _logistic_sparse_effective_max_iter(max_iter, device, penalty_name, refit=False):
    """CV iteration budget for sparse-penalty logistic fits.

    On GPU backends (cupy/torch) and outside refit, L1 is capped at
    ``_FISTA_MAX_ITER_CV`` and elastic net at ``_FISTA_MAX_ITER_CV_SMALL``.
    Every other case returns ``int(max_iter)`` unchanged.
    """
    backend = _backend_name_for_cv_device(device)
    penalty = str(penalty_name).lower()
    cap = None
    if backend in ("cupy", "torch") and not refit:
        cap = {
            "l1": _FISTA_MAX_ITER_CV,
            "elasticnet": _FISTA_MAX_ITER_CV_SMALL,
            "en": _FISTA_MAX_ITER_CV_SMALL,
        }.get(penalty)
    if cap is None:
        return int(max_iter)
    return min(int(max_iter), cap)
335
+
336
+
337
def _glm_cv_effective_max_iter(max_iter, loss_name, penalty_name, device, refit=False):
    """CV-only iteration caps for GPU paths whose alpha ranking stabilizes early.

    Outside refit:

    - Tweedie with l1/elasticnet on cupy or torch is capped at 200.
    - Negative binomial with l2 on cupy is capped at 30.

    Otherwise ``int(max_iter)`` is returned.
    """
    backend = _backend_name_for_cv_device(device)
    loss = str(loss_name).lower()
    penalty = str(penalty_name).lower()
    cap = None
    if not refit:
        if backend in ("cupy", "torch") and loss == "tweedie" and penalty in ("l1", "elasticnet", "en"):
            cap = 200
        elif backend == "cupy" and loss == "negative_binomial" and penalty == "l2":
            cap = 30
    return int(max_iter) if cap is None else min(int(max_iter), cap)
349
+
350
+
351
def _to_backend_float64(arr, backend):
    """Return ``arr`` as a float64 array of the named backend.

    - "torch": existing tensors keep their device and are only cast. Other
      inputs go through numpy and are placed on the current CUDA device, or
      on CPU when CUDA is unavailable.
    - "cupy": converted with ``cp.asarray``.
    - anything else: converted to a numpy float64 array.
    """
    if backend == "torch":
        import torch
        if not isinstance(arr, torch.Tensor):
            target = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
            return torch.as_tensor(np.asarray(arr, dtype=np.float64), dtype=torch.float64, device=target)
        return arr.to(dtype=torch.float64)
    if backend == "cupy":
        import cupy as cp
        return cp.asarray(arr, dtype=cp.float64)
    return np.asarray(arr, dtype=np.float64)
364
+
365
+
366
+ # ---------------------------------------------------------------------------
367
+ # Unified fold-batched CV framework
368
+ # ---------------------------------------------------------------------------
369
+
370
def _fb_ones(shape, dtype, is_torch, device=None):
    """Ones array: torch on ``device`` when ``is_torch``, otherwise CuPy (``device`` ignored)."""
    if not is_torch:
        import cupy as cp
        return cp.ones(shape, dtype=dtype)
    import torch
    return torch.ones(shape, dtype=dtype, device=device)


def _fb_zeros(shape, dtype, is_torch, device=None):
    """Zeros array: torch on ``device`` when ``is_torch``, otherwise CuPy (``device`` ignored)."""
    if not is_torch:
        import cupy as cp
        return cp.zeros(shape, dtype=dtype)
    import torch
    return torch.zeros(shape, dtype=dtype, device=device)


def _fb_as_tensor(arr, is_torch, device=None):
    """Convert index-like data to an int64 torch tensor (on ``device``) or CuPy array."""
    indices = np.asarray(arr, dtype=np.int64)
    if not is_torch:
        import cupy as cp
        return cp.asarray(indices)
    import torch
    return torch.as_tensor(indices, dtype=torch.long, device=device)


def _fb_copy(x, is_torch):
    """Copy via ``.clone()`` for torch tensors, otherwise via ``.copy()``."""
    if is_torch:
        return x.clone()
    return x.copy()


def _fb_cat(tensors, is_torch, dim=1):
    """Concatenate ``tensors`` along ``dim`` (torch.cat / cp.concatenate)."""
    if not is_torch:
        import cupy as cp
        return cp.concatenate(tensors, axis=dim)
    import torch
    return torch.cat(tensors, dim=dim)


def _fb_sum(x, is_torch, axis=0, keepdims=False):
    """Sum along ``axis`` using the torch (``dim``/``keepdim``) or array-API keyword names."""
    if not is_torch:
        return x.sum(axis=axis, keepdims=keepdims)
    return x.sum(dim=axis, keepdim=keepdims)


def _fb_stack(arrays, is_torch, dim=1):
    """Stack ``arrays`` along a new axis ``dim`` (torch.stack / cp.stack)."""
    if not is_torch:
        import cupy as cp
        return cp.stack(arrays, axis=dim)
    import torch
    return torch.stack(arrays, dim=dim)
426
+
427
+
428
def _fold_batch_lipschitz_logistic(X_aug, y_train, n_train, is_torch):
    """Lipschitz bound for the mean logistic loss: lambda_max(X'X) / (4 n).

    The result is floored at 1e-12. ``y_train`` and ``is_torch`` are unused
    and only keep the signature aligned with the other fold-batch Lipschitz
    helpers.
    """
    largest_eig = _max_eigval_power(X_aug.T @ X_aug)
    denominator = 4.0 * max(n_train, 1)
    return max(largest_eig / denominator, 1e-12)
431
+
432
+
433
def _fold_batch_lipschitz_exp_link(X_aug, y_train, n_train, is_torch):
    """Lipschitz for log-link GLMs (Poisson, Gamma, NB, InvGauss, Tweedie).

    Computes ``max(lambda_max(X'X) / n, 1e-12) * y_scale`` with
    ``y_scale = max(1, y_mean, sqrt(y_mean * max(y_max, 1e-10)))``.

    ``is_torch`` is kept for signature compatibility. ``float()`` converts a
    0-d torch tensor, CuPy array, or numpy scalar to a Python float, so no
    backend branch or backend import is needed. (The old unused ``import
    cupy`` also made this fail for numpy input without CuPy installed.)
    """
    eig_max = _max_eigval_power(X_aug.T @ X_aug)
    y_mean = float(y_train.mean())
    y_max = float(y_train.max())
    y_scale = max(1.0, y_mean, np.sqrt(y_mean * max(y_max, 1e-10)))
    return max(eig_max / max(n_train, 1), 1e-12) * y_scale
447
+
448
+
449
def _fold_batch_lipschitz_gamma(X_aug, y_train, n_train, is_torch):
    """Lipschitz for Gamma log-link: eig_max(X'X)/n * max(y/y_mean).

    Differs from _fold_batch_lipschitz_exp_link because Gamma's Hessian
    weights are y/mu (not mu), so scaling uses the y-ratio instead of a
    y-moment. The result is ``max(eig_max / n, 1e-12) * max(1, y_ratio_max)``,
    with ``y_ratio_max`` taken as 1.0 when ``y_mean`` is not positive.

    ``is_torch`` is kept for signature compatibility. ``float()`` handles
    0-d torch/CuPy/numpy results, so no backend-specific import is needed.
    """
    eig_max = _max_eigval_power(X_aug.T @ X_aug)
    y_mean = float(y_train.mean())
    y_ratio_max = float((y_train / y_mean).max()) if y_mean > 0 else 1.0
    return max(eig_max / max(n_train, 1), 1e-12) * max(1.0, y_ratio_max)
465
+
466
+
467
+ # Loss-specific fold-batch configs (_FOLD_BATCH_CONFIGS, defined further below) hold only lipschitz_fn and intercept_fn.
468
+ # ---------------------------------------------------------------------------
469
+ # Loss formula registry — single source of truth for residual and val_loss
470
+ # ---------------------------------------------------------------------------
471
+ # Each loss registers (residual_fn, val_loss_fn) that work with any backend
472
+ # (numpy/torch/cupy) via elementwise ops. The FISTA hot loop calls these
473
+ # instead of inline if/elif chains, eliminating formula duplication.
474
+ #
475
+ # Signature: fn(eta, y, **params) -> per_sample_loss_or_residual
476
+ # `eta` and `y` are backend arrays; `params` carries loss-specific scalars.
477
+
478
+ # Use backend-agnostic utilities from statgpu.backends._array_ops
479
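+ # Must be imported before these helpers are *called* (globals resolve at call time), so _res_logistic etc.
481
+ # can use _sigmoid/_softplus and the earlier _fold_batch_lipschitz_* helpers can use _max_eigval_power.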
481
+ from statgpu.backends._array_ops import (
482
+ _clip as _safe_clip,
483
+ _xp as _get_xp,
484
+ _sigmoid,
485
+ _softplus,
486
+ _abs_sum_dev,
487
+ _device_gt,
488
+ _max_eigval_power,
489
+ )
490
+
491
+
492
_LOSS_RESIDUAL_FNS = {}
_LOSS_VALLOSS_FNS = {}


def _register_loss_fns(loss_name, residual_fn, val_loss_fn):
    """Record the per-sample residual (gradient) and validation-loss callables for ``loss_name``."""
    for registry, fn in ((_LOSS_RESIDUAL_FNS, residual_fn), (_LOSS_VALLOSS_FNS, val_loss_fn)):
        registry[loss_name] = fn
499
+
500
# --- Logistic ---
def _res_logistic(eta, y, **_):
    """Per-sample gradient of the logistic loss w.r.t. eta: ``sigmoid(eta) - y``."""
    prob = _sigmoid(eta)
    return prob - y


def _val_logistic(eta, y, **_):
    """Per-sample logistic loss: ``-y * eta + softplus(eta)``."""
    linear_term = -y * eta
    return linear_term + _softplus(eta)
508
+
509
# --- Poisson ---
def _res_poisson(eta, y, **_):
    """Per-sample Poisson gradient w.r.t. eta: ``mu - y``.

    This is d/deta of ``mu - y*log(mu)``, with ``mu = exp(clip(eta))``
    floored at ``_MU_LO``.
    """
    xp = _get_xp(eta)
    eta_c = _safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD)
    mean = _safe_clip(xp.exp(eta_c), _MU_LO, None)
    return mean - y


def _val_poisson(eta, y, **_):
    """Per-sample Poisson loss ``mu - y*log(mu)``, with mu clipped as in ``_res_poisson``."""
    xp = _get_xp(eta)
    eta_c = _safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD)
    mean = _safe_clip(xp.exp(eta_c), _MU_LO, None)
    return mean - y * xp.log(mean)
522
+
523
# --- Gamma ---
def _res_gamma(eta, y, **_):
    """Per-sample Gamma (log-link) gradient w.r.t. eta: ``1 - y / mu``.

    ``mu = exp(clip(eta))`` is floored at ``_MU_LO``.
    """
    xp = _get_xp(eta)
    eta_c = _safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD)
    mean = _safe_clip(xp.exp(eta_c), _MU_LO, None)
    return 1.0 - y / mean


def _val_gamma(eta, y, **_):
    """Per-sample Gamma loss ``y / mu + log(mu)``, with mu clipped as in ``_res_gamma``."""
    xp = _get_xp(eta)
    eta_c = _safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD)
    mean = _safe_clip(xp.exp(eta_c), _MU_LO, None)
    return y / mean + xp.log(mean)
535
+
536
# --- Inverse Gaussian ---
def _res_invgauss(eta, y, **_):
    """Per-sample inverse Gaussian (log-link) gradient w.r.t. eta: ``(mu - y) / mu^2``.

    ``mu^2`` itself (not ``mu``) is floored at ``_MU_LO``. Flooring ``mu`` at
    1e-10 would still let the denominator shrink to about 1e-20.
    """
    xp = _get_xp(eta)
    mu = xp.exp(_safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD))
    denom = _safe_clip(mu * mu, _MU_LO, None)
    return (mu - y) / denom


def _val_invgauss(eta, y, **_):
    """Per-sample inverse Gaussian loss ``y / (2 mu^2) - 1 / mu``.

    ``mu^2`` is floored at ``_MU_LO``, the same floor as in ``_res_invgauss``,
    so ``2 * mu^2 >= 2e-10``. ``mu`` in the second term is floored separately
    at ``_MU_LO``.
    """
    xp = _get_xp(eta)
    mu = xp.exp(_safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD))
    mu_sq = _safe_clip(mu * mu, _MU_LO, None)
    mu_floor = _safe_clip(mu, _MU_LO, None)
    return y / (2.0 * mu_sq) - 1.0 / mu_floor
552
+
553
# --- Negative Binomial ---
def _res_nb(eta, y, alpha=_NB_ALPHA_DEFAULT, **_):
    """Per-sample NB (log-link) gradient w.r.t. eta: ``(mu - y) / (1 + alpha*mu)``.

    ``mu = exp(clip(eta))`` is floored at ``_MU_LO``. At ``alpha == 0`` this
    reduces to the Poisson gradient ``mu - y``.
    """
    xp = _get_xp(eta)
    mu = xp.exp(_safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD))
    mu_c = _safe_clip(mu, _MU_LO, None)
    return (mu_c - y) / (1.0 + alpha * mu_c)


def _val_nb(eta, y, alpha=_NB_ALPHA_DEFAULT, **_):
    """Per-sample NB loss ``-y*log(mu/(1+alpha*mu)) + log(1+alpha*mu)/alpha``.

    Terms that are constant in mu are omitted. For a scalar ``alpha == 0``,
    the ``alpha -> 0`` limit ``mu - y*log(mu)`` (the Poisson loss) is returned
    instead of dividing by zero. This keeps it consistent with ``_res_nb``,
    which already yields the Poisson gradient there.
    """
    xp = _get_xp(eta)
    mu = xp.exp(_safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD))
    mu_c = _safe_clip(mu, _MU_LO, None)
    # Only scalar alphas are checked; array-valued alpha keeps the original
    # broadcasting formula.
    if isinstance(alpha, (int, float)) and alpha == 0:
        return mu_c - y * xp.log(mu_c)
    one_plus = 1.0 + alpha * mu_c
    return -y * xp.log(mu_c / one_plus) + (1.0 / alpha) * xp.log(one_plus)
567
+
568
# --- Tweedie ---
def _res_tweedie(eta, y, power=_TWEEDIE_POWER_DEFAULT, **_):
    """Per-sample Tweedie (log-link) gradient w.r.t. eta: ``mu^(1-p) * (mu - y)``.

    ``eta`` is clipped to +/-``_ETA_CLIP_TWEEDIE`` and ``mu`` to
    ``[_MU_LO_TWEEDIE, _MU_HI_TWEEDIE]``.
    """
    xp = _get_xp(eta)
    eta_c = _safe_clip(eta, -_ETA_CLIP_TWEEDIE, _ETA_CLIP_TWEEDIE)
    mean = _safe_clip(xp.exp(eta_c), _MU_LO_TWEEDIE, _MU_HI_TWEEDIE)
    variance_factor = xp.exp((1 - power) * xp.log(mean))
    return variance_factor * (mean - y)


def _val_tweedie(eta, y, power=_TWEEDIE_POWER_DEFAULT, **_):
    """Per-sample Tweedie loss ``-y*mu^(1-p)/(1-p) + mu^(2-p)/(2-p)``.

    At the boundaries p == 1 (Poisson) and p == 2 (Gamma), within 1e-10, the
    corresponding term switches to its log form. mu is clipped as in
    ``_res_tweedie``.
    """
    xp = _get_xp(eta)
    eta_c = _safe_clip(eta, -_ETA_CLIP_TWEEDIE, _ETA_CLIP_TWEEDIE)
    mean = _safe_clip(xp.exp(eta_c), _MU_LO_TWEEDIE, _MU_HI_TWEEDIE)
    log_mean = xp.log(mean)
    d1 = 1.0 - power
    d2 = 2.0 - power
    if abs(d1) > 1e-10:
        term1 = -y * xp.exp(d1 * log_mean) / d1
    else:
        term1 = -y * log_mean
    if abs(d2) > 1e-10:
        term2 = xp.exp(d2 * log_mean) / d2
    else:
        term2 = log_mean
    return term1 + term2
587
+
588
+
589
# Register (residual, validation-loss) formula pairs for each GLM loss, in
# the same order as before.
for _loss_formula_spec in (
    ("logistic", _res_logistic, _val_logistic),
    ("poisson", _res_poisson, _val_poisson),
    ("gamma", _res_gamma, _val_gamma),
    ("inverse_gaussian", _res_invgauss, _val_invgauss),
    ("negative_binomial", _res_nb, _val_nb),
    ("tweedie", _res_tweedie, _val_tweedie),
):
    _register_loss_fns(*_loss_formula_spec)
del _loss_formula_spec
595
+
596
+
597
# Validation-loss dispatch used by _evaluate_loss_numpy. Every entry except
# squared error reuses a backend-agnostic _val_* function: those select
# numpy/cupy/torch through _get_xp, and this path always hands them NumPy
# arrays, so they act as NumPy-only loss evaluators here.
# NOTE(review): the boolean flag is interpreted by the dispatch consumer,
# which is not visible in this chunk -- confirm its meaning there.
_LOSS_EVAL_DISPATCH.update(
    logistic=(_val_logistic, False),
    squared_error=(_ps_squared_error, True),
    poisson=(_val_poisson, False),
    gamma=(_val_gamma, False),
    inverse_gaussian=(_val_invgauss, False),
    negative_binomial=(_val_nb, False),
    tweedie=(_val_tweedie, False),
)
609
+
610
+
611
+ # Fold-batch config: Lipschitz function and intercept function per loss.
612
+ # Residual and val_loss are handled by the loss formula registry above.
613
+ _FOLD_BATCH_CONFIGS = {}
614
+
615
+
616
+ def _logistic_intercept(y_mean, is_torch):
617
+ if is_torch:
618
+ import torch
619
+ y_prob = torch.clamp(y_mean, min=1e-3, max=0.999)
620
+ return torch.log(y_prob) - torch.log(1.0 - y_prob)
621
+ else:
622
+ import cupy as cp
623
+ y_prob = cp.clip(y_mean, 1e-3, 0.999)
624
+ return cp.log(y_prob) - cp.log(1.0 - y_prob)
625
+
626
+
627
+ def _exp_link_intercept(y_mean, is_torch):
628
+ """Intercept for log-link GLMs: log(clamp(y_mean, 1e-3, 100))."""
629
+ if is_torch:
630
+ import torch
631
+ return torch.log(torch.clamp(y_mean, min=1e-3, max=100.0))
632
+ else:
633
+ import cupy as cp
634
+ return cp.log(cp.clip(y_mean, 1e-3, 100.0))
635
+
636
+
637
def _register_fold_batch(loss_name, lipschitz_fn, intercept_fn):
    """Record (replacing any earlier entry) the fold-batch callables for ``loss_name``."""
    entry = dict(lipschitz_fn=lipschitz_fn, intercept_fn=intercept_fn)
    _FOLD_BATCH_CONFIGS[loss_name] = entry
642
+
643
+
644
# Fold-batch registrations. Every loss except logistic uses the log-link
# intercept; gamma and logistic have their own Lipschitz bounds.
for _fold_batch_spec in (
    ("logistic", _fold_batch_lipschitz_logistic, _logistic_intercept),
    ("poisson", _fold_batch_lipschitz_exp_link, _exp_link_intercept),
    ("gamma", _fold_batch_lipschitz_gamma, _exp_link_intercept),
    ("inverse_gaussian", _fold_batch_lipschitz_exp_link, _exp_link_intercept),
    ("negative_binomial", _fold_batch_lipschitz_exp_link, _exp_link_intercept),
    ("tweedie", _fold_batch_lipschitz_exp_link, _exp_link_intercept),
):
    _register_fold_batch(*_fold_batch_spec)
del _fold_batch_spec
650
+
651
+
652
def _glm_sparse_cv_folds(
    X,
    y,
    folds,
    alpha_sorted,
    penalty_name,
    l1_ratio,
    max_iter,
    tol,
    loss_name,
    device_backend,
    sample_weight=None,
    loss_kwargs=None,
):
    """Unified fold-batched sparse GLM CV path for all losses and backends.

    All folds are solved at once. Coefficients are stored as an
    ``(n_features, n_folds)`` matrix, and each fold selects its rows of the
    full design matrix through per-sample train/validation masks. Uses
    direct torch/cupy API calls (no abstraction layer) for performance in
    the FISTA hot loop.

    Parameters
    ----------
    X, y : array-like
        Full design matrix and response; converted to float64 on the device.
    folds : sequence of (train_idx, val_idx)
        Index pairs, one per fold. A sample in neither index set of a fold
        takes part in neither training nor validation for that fold.
    alpha_sorted : array-like
        Penalty strengths, solved in the given order with warm starts.
    penalty_name : str
        ``"elasticnet"``/``"en"`` selects elastic net; any other value uses
        the pure L1 proximal step.
    l1_ratio : float
        Elastic-net mixing parameter (only used for elastic net).
    max_iter : int
        FISTA iteration cap per alpha.
    tol : float
        A fold freezes once the L1 change of its coefficients plus the
        absolute change of its intercept drops below ``tol``.
    loss_name : str
        Key into ``_FOLD_BATCH_CONFIGS`` and the loss formula registry.
    device_backend : str
        ``"torch"`` selects PyTorch (CUDA required); any other value selects
        CuPy.
    sample_weight : array-like, optional
        Per-sample weights applied to training gradients, Lipschitz bounds
        and validation scores.
    loss_kwargs : dict, optional
        Loss overrides: ``alpha`` (negative binomial), ``power`` (Tweedie).

    Returns
    -------
    dict or None
        ``{"scores": float64 array, "n_iter": int64 array}``, built by
        stacking one length-``n_folds`` entry per alpha with ``_fb_stack``.
        Returns ``None`` when this fast path does not apply: unregistered
        loss, backend/GPU unavailable, fewer than two folds, or an empty
        alpha grid.
    """
    cfg = _FOLD_BATCH_CONFIGS.get(loss_name)
    if cfg is None:
        return None

    # Resolve loss-specific parameters. User-specified kwargs override the
    # resolved loss object's attributes, which override module defaults.
    _lk = loss_kwargs or {}
    from statgpu.linear_model.penalized._fit_mixin import _resolve_loss_name
    _loss_obj = _resolve_loss_name(loss_name, loss_kwargs=_lk)
    _nb_alpha = float(_lk.get("alpha", getattr(_loss_obj, "alpha", _NB_ALPHA_DEFAULT)))
    _tw_power = float(_lk.get("power", getattr(_loss_obj, "power", _TWEEDIE_POWER_DEFAULT)))

    is_torch = (device_backend == "torch")
    if is_torch:
        if _backend_name_for_cv_device("torch") != "torch":
            return None
        try:
            import torch
            if not torch.cuda.is_available():
                return None
        except (ImportError, RuntimeError, OSError):
            return None
    else:
        if _backend_name_for_cv_device("cuda") != "cupy":
            return None
        try:
            import cupy as cp
            if cp.cuda.runtime.getDeviceCount() <= 0:
                return None
        except (ImportError, RuntimeError, OSError):
            return None

    Xb = _to_backend_float64(X, device_backend)
    yb = _to_backend_float64(y, device_backend).reshape(-1)
    alphas = np.asarray(alpha_sorted, dtype=np.float64).ravel()
    penalty_name = str(penalty_name).lower()
    is_enet = penalty_name in ("elasticnet", "en")
    n_samples, n_features = Xb.shape
    n_folds = len(folds)
    if n_folds < 2 or alphas.size == 0:
        return None

    lipschitz_fn = cfg["lipschitz_fn"]
    intercept_fn = cfg["intercept_fn"]
    _where = torch.where if is_torch else cp.where
    _full_like = torch.full_like if is_torch else cp.full_like

    # --- Build masks and compute per-fold Lipschitz ---
    # Both masks start at zero and are filled from each fold's own indices.
    # Initialising train_mask to ones and zeroing only validation rows would
    # train on samples that belong to neither set (e.g. the held-back tail
    # of time-series splits). It would also make the gradient normaliser
    # disagree with the per-fold Lipschitz bound, which uses X[train_idx].
    dev = Xb.device if is_torch else None
    train_mask = _fb_zeros((n_samples, n_folds), Xb.dtype, is_torch, dev)
    val_mask = _fb_zeros((n_samples, n_folds), Xb.dtype, is_torch, dev)

    has_weights = sample_weight is not None
    if has_weights:
        sw_all = _to_backend_float64(sample_weight, device_backend).reshape(-1)

    step_values = []
    for fold_idx, (train_idx, val_idx) in enumerate(folds):
        train_idx_dev = _fb_as_tensor(train_idx, is_torch, dev)
        val_idx_dev = _fb_as_tensor(val_idx, is_torch, dev)
        train_mask[train_idx_dev, fold_idx] = 1.0
        val_mask[val_idx_dev, fold_idx] = 1.0

        X_train = Xb[train_idx_dev]
        y_train = yb[train_idx_dev]
        ones = _fb_ones((X_train.shape[0], 1), Xb.dtype, is_torch, dev)
        X_aug = _fb_cat([X_train, ones], is_torch)
        if has_weights:
            # Weighted bound: scale rows by sqrt(w), normalise by sum(w).
            sw_fold = sw_all[train_idx_dev]
            sw_col_fold = sw_fold.reshape(-1, 1)
            Xw = X_aug * _get_xp(sw_col_fold).sqrt(sw_col_fold)
            w_sum_fold = float(sw_fold.sum().item()) if is_torch else float(sw_fold.sum())
            L_loss = lipschitz_fn(Xw, y_train, max(w_sum_fold, 1.0), is_torch)
        else:
            L_loss = lipschitz_fn(X_aug, y_train, int(X_train.shape[0]), is_torch)
        step_values.append(1.0 / L_loss)

    # Per-sample, per-fold training weights (zero outside each training set).
    sw_mask = sw_all.reshape(-1, 1) * train_mask if has_weights else train_mask

    # --- Initialize parameters ---
    sw_train_vec = _fb_sum(sw_mask, is_torch, axis=0, keepdims=True).reshape(1, n_folds)
    # Guard against zero-weight folds (would cause division by zero).
    if is_torch:
        sw_train_vec = torch.clamp(sw_train_vec, min=1e-10)
    else:
        sw_train_vec = cp.clip(sw_train_vec, 1e-10, None)
    n_val_vec = _fb_sum(val_mask, is_torch, axis=0, keepdims=True).reshape(1, n_folds)
    # Guard against empty validation folds.
    if is_torch:
        n_val_vec = torch.clamp(n_val_vec, min=1.0)
    else:
        n_val_vec = cp.maximum(n_val_vec, 1.0)
    y_col = yb.reshape(-1, 1)
    # Weighted mean of y per fold seeds the intercept.
    y_mean = _fb_sum(y_col * sw_mask, is_torch, axis=0, keepdims=True) / sw_train_vec
    intercept = intercept_fn(y_mean, is_torch).reshape(1, n_folds)
    coef = _fb_zeros((n_features, n_folds), Xb.dtype, is_torch, dev)
    if is_torch:
        step = torch.as_tensor(step_values, dtype=Xb.dtype, device=dev).reshape(1, n_folds)
    else:
        step = cp.asarray(step_values, dtype=Xb.dtype).reshape(1, n_folds)

    tol_float = float(tol)
    scores_path = []
    iters_path = []

    # Pre-build loss kwargs to avoid dict construction in the hot loop.
    _loss_kwargs = {}
    if loss_name == "negative_binomial":
        _loss_kwargs["alpha"] = _nb_alpha
    elif loss_name == "tweedie":
        _loss_kwargs["power"] = _tw_power

    _resid_fn = _LOSS_RESIDUAL_FNS[loss_name]
    _valloss_fn = _LOSS_VALLOSS_FNS[loss_name]

    # Validation weights are constant across alphas.
    if has_weights:
        sw_val_mask = sw_all.reshape(-1, 1) * val_mask
        sw_val_vec = _fb_sum(sw_val_mask, is_torch, axis=0, keepdims=True).reshape(1, n_folds)
        sw_val_vec = torch.clamp(sw_val_vec, min=1e-10) if is_torch else cp.clip(sw_val_vec, 1e-10, None)

    # --- FISTA loop over alphas ---
    # y_coef / y_intercept are the extrapolated iterates (standard FISTA notation).
    for alpha in alphas:
        y_coef = _fb_copy(coef, is_torch)
        y_intercept = _fb_copy(intercept, is_torch)
        t_k = 1.0
        if is_torch:
            active = torch.ones((1, n_folds), dtype=torch.bool, device=Xb.device)
            last_iter = torch.zeros((n_folds,), dtype=torch.int64, device=Xb.device)
        else:
            active = cp.ones((1, n_folds), dtype=bool)
            last_iter = cp.zeros((n_folds,), dtype=cp.int64)

        # Proximal parameters depend only on alpha and the per-fold step.
        if is_enet:
            thresh = float(alpha) * float(l1_ratio) * step
            denom = 1.0 + float(alpha) * (1.0 - float(l1_ratio)) * step
        else:
            thresh = float(alpha) * step
            denom = 1.0

        for iteration in range(int(max_iter)):
            coef_old = _fb_copy(coef, is_torch)
            intercept_old = _fb_copy(intercept, is_torch)

            eta = Xb @ y_coef + y_intercept
            resid = _resid_fn(eta, y_col, **_loss_kwargs) * train_mask
            # Weighted gradient, normalised by the per-fold weight total.
            weighted_resid = resid * sw_mask
            grad_coef = (Xb.T @ weighted_resid) / sw_train_vec
            grad_intercept = _fb_sum(weighted_resid, is_torch, axis=0, keepdims=True) / sw_train_vec

            coef_new = _soft_threshold(y_coef - step * grad_coef, thresh) / denom
            intercept_new = y_intercept - step * grad_intercept

            # Converged (inactive) folds keep their current iterate.
            coef = _where(active, coef_new, coef)
            intercept = _where(active, intercept_new, intercept)

            beta, t_k = _nesterov_momentum(t_k, beta_cap=0.5)
            y_coef = _where(active, coef + beta * (coef - coef_old), coef)
            y_intercept = _where(active, intercept + beta * (intercept - intercept_old), intercept)
            last_iter = _where(active.reshape(-1), _full_like(last_iter, iteration + 1), last_iter)

            # Check convergence every iteration for the first 20, then every 50.
            if iteration < 20 or iteration % 50 == 0:
                if is_torch:
                    delta = torch.sum(torch.abs(coef - coef_old), dim=0, keepdim=True) + torch.abs(intercept - intercept_old)
                else:
                    delta = cp.sum(cp.abs(coef - coef_old), axis=0, keepdims=True) + cp.abs(intercept - intercept_old)
                active = active & (delta >= tol_float)
                _any_active = torch.any(active) if is_torch else cp.any(active)
                if not _to_float_scalar(_any_active):
                    break

        # Validation loss via loss registry (single call, backend-agnostic).
        eta_val = Xb @ coef + intercept
        val_loss = _valloss_fn(eta_val, y_col, **_loss_kwargs) * val_mask
        if has_weights:
            fold_scores = _fb_sum(val_loss * sw_val_mask, is_torch, axis=0, keepdims=True).reshape(-1) / sw_val_vec.reshape(-1)
        else:
            fold_scores = _fb_sum(val_loss, is_torch, axis=0, keepdims=True).reshape(-1) / n_val_vec.reshape(-1)
        scores_path.append(fold_scores)
        iters_path.append(last_iter)

    scores = _fb_stack(scores_path, is_torch)
    n_iter = _fb_stack(iters_path, is_torch)
    return {
        "scores": np.asarray(_to_numpy(scores), dtype=np.float64),
        "n_iter": np.asarray(_to_numpy(n_iter), dtype=np.int64),
    }
877
+
878
+
879
def _scalar_to_float(x):
    """Return ``x`` as a Python ``float`` after converting it with ``_to_numpy``."""
    host_value = _to_numpy(x)
    return float(host_value)
881
+
882
+
883
def _logistic_sparse_cv_path(
    X_train,
    y_train,
    alpha_sorted,
    penalty_name,
    l1_ratio,
    max_iter,
    tol,
    device,
    X_val=None,
    y_val=None,
    sample_weight=None,
    val_sample_weight=None,
    return_path=True,
):
    """Fit a logistic sparse alpha path and optionally score validation loss.

    This CV-only path uses a fixed global Lipschitz bound and sparse proximal
    updates, avoiding per-iteration Armijo/objective synchronizations. Alphas
    are solved in the given order, each warm-started from the previous one.

    Parameters
    ----------
    X_train, y_train : array-like
        Training design matrix and response (the loss is
        ``-y*eta + softplus(eta)``, so ``y`` is expected in [0, 1]).
    alpha_sorted : array-like
        Penalty strengths.
    penalty_name : str
        ``"elasticnet"``/``"en"`` for elastic net; otherwise pure L1.
    l1_ratio : float
        Elastic-net mixing parameter.
    max_iter, tol : int, float
        Iteration cap and L1 change tolerance (coefficients + intercept).
    device : str
        CV device name, mapped via ``_backend_name_for_cv_device``.
    X_val, y_val : array-like, optional
        Validation data; scores are computed only when both are given.
    sample_weight : array-like, optional
        Must be uniform; otherwise a RuntimeWarning is emitted and ``None``
        is returned so the caller can fall back to the general CV path.
    val_sample_weight : array-like, optional
        Per-sample weights for validation scoring. When provided (and their
        sum is positive), validation loss is the weighted mean.
    return_path : bool
        Also return the per-alpha coefficients and intercepts.

    Returns
    -------
    dict or None
        ``"scores"`` (float64 array or ``None`` without validation data) and
        ``"n_iter"``; with ``return_path`` also ``"coef"``
        (``(n_alphas, n_features)``) and ``"intercept"``.
    """
    if not _is_uniform_weight(sample_weight):
        warnings.warn(
            "_logistic_sparse_cv_path: non-uniform sample_weight not supported, "
            "falling back to general CV path.",
            RuntimeWarning,
            stacklevel=2,
        )
        return None

    backend = _backend_name_for_cv_device(device)
    Xb = _to_backend_float64(X_train, backend)
    yb = _to_backend_float64(y_train, backend).reshape(-1)
    alphas = np.asarray(alpha_sorted, dtype=np.float64).ravel()
    n_samples, n_features = Xb.shape

    from statgpu.backends._utils import _get_xp, xp_ones
    from statgpu.backends._array_ops import _scalar_tensor

    xp = _get_xp(backend)
    ones = xp_ones((n_samples, 1), dtype=Xb.dtype, xp=xp, ref_arr=Xb)
    X_aug = xp.concatenate([Xb, ones], axis=1)
    coef = _zeros(n_features, backend, ref_tensor=Xb)
    # Start the intercept at the logit of the clipped training mean.
    p0 = np.clip(_to_float_scalar(xp.mean(yb)), 1e-3, 1.0 - 1e-3)
    intercept = _scalar_tensor(np.log(p0 / (1.0 - p0)), Xb)

    # sigmoid' <= 1/4, so lambda_max(X_aug^T X_aug) / (4 n) bounds the
    # gradient's Lipschitz constant.
    eig_max = _max_eigval_power(X_aug.T @ X_aug)
    L_loss = max(eig_max / (4.0 * max(int(n_samples), 1)), 1e-12)
    step = 1.0 / L_loss
    conv_interval = _CONV_INTERVAL_CV_NUMPY if backend == "numpy" else _CONV_INTERVAL_CV_FOLD
    penalty_name = str(penalty_name).lower()
    is_enet = penalty_name in ("elasticnet", "en")

    if X_val is not None and y_val is not None:
        Xv = _to_backend_float64(X_val, backend)
        yv = _to_backend_float64(y_val, backend).reshape(-1)
        swv = _to_backend_float64(val_sample_weight, backend).reshape(-1) if val_sample_weight is not None else None
    else:
        Xv = yv = swv = None

    # The validation-weight total does not depend on alpha: compute and
    # sync it once instead of once per alpha.
    use_val_weights = False
    if swv is not None:
        sw_sum = swv.sum()
        use_val_weights = _to_float_scalar(sw_sum) > 0

    val_losses = []       # numpy/cupy: one validation loss per alpha
    val_coefs = []        # torch: coef snapshots for one batched GEMM
    val_intercepts = []   # torch: matching intercept snapshots
    coef_path = []
    intercept_path = []
    iters = []

    for alpha in alphas:
        if is_enet:
            thresh = float(alpha) * float(l1_ratio) * step
            denom = 1.0 + float(alpha) * (1.0 - float(l1_ratio)) * step
        else:
            thresh = float(alpha) * step
            denom = 1.0

        y_coef = _copy_arr(coef)
        # ``intercept`` is only ever rebound, never mutated in place, so a
        # reference is a valid snapshot. (A float(...) conversion here would
        # force a device->host sync every iteration on CuPy.)
        y_intercept = intercept
        t_k = 1.0
        last_iter = 0
        for iteration in range(int(max_iter)):
            coef_old = _copy_arr(coef)
            intercept_old = intercept

            eta = Xb @ y_coef + y_intercept
            resid = _sigmoid(eta) - yb
            grad_coef = Xb.T @ resid / n_samples
            grad_intercept = xp.mean(resid)

            w = y_coef - step * grad_coef
            coef = _soft_threshold(w, thresh) / denom
            intercept = y_intercept - step * grad_intercept

            beta, t_k = _nesterov_momentum(t_k, beta_cap=0.5)
            y_coef = coef + beta * (coef - coef_old)
            y_intercept = intercept + beta * (intercept - intercept_old)
            last_iter = iteration + 1

            if iteration < 20 or iteration % conv_interval == 0:
                delta = xp.sum(xp.abs(coef - coef_old)) + xp.abs(intercept - intercept_old)
                if _to_float_scalar(delta) < tol:
                    break

        if Xv is not None:
            if backend == "torch":
                val_coefs.append(coef.clone())
                val_intercepts.append(intercept.clone())
            else:
                eta_v = Xv @ coef + intercept
                per_sample = -yv * eta_v + _softplus(eta_v)
                if use_val_weights:
                    val_losses.append(xp.sum(swv * per_sample) / sw_sum)
                else:
                    val_losses.append(xp.mean(per_sample))
        if return_path:
            coef_path.append(np.asarray(_to_numpy(coef), dtype=np.float64).copy())
            intercept_path.append(_scalar_to_float(intercept))
        iters.append(last_iter)

    # Torch benefits from one alpha-path GEMM for validation. For NumPy/CuPy
    # at these small alpha-grid widths, per-alpha GEMV is consistently steadier.
    scores = []
    if val_coefs:
        import torch
        coef_mat = torch.stack(val_coefs, dim=1)
        intercept_vec = torch.stack(val_intercepts).reshape(1, -1)
        eta_v = Xv @ coef_mat + intercept_vec
        per_sample = -yv.reshape(-1, 1) * eta_v + _softplus(eta_v)
        if use_val_weights:
            scores_tensor = (swv.reshape(-1, 1) * per_sample).sum(dim=0) / sw_sum
        else:
            scores_tensor = per_sample.mean(dim=0)
        scores = _to_numpy(scores_tensor).tolist()
    elif val_losses:
        scores = [float(s) for s in val_losses]

    out = {
        "scores": np.asarray(scores, dtype=np.float64) if scores else None,
        "n_iter": np.asarray(iters, dtype=np.int64),
    }
    if return_path:
        # np.vstack raises on an empty list, so handle an empty alpha grid.
        out["coef"] = (
            np.vstack(coef_path).astype(np.float64, copy=False)
            if coef_path
            else np.empty((0, int(n_features)), dtype=np.float64)
        )
        out["intercept"] = np.asarray(intercept_path, dtype=np.float64)
    return out
1040
+
1041
+
1042
+ # (Old per-loss fold-batched functions removed — replaced by _glm_sparse_cv_folds)
1043
+
1044
+
1045
def _squared_error_sparse_cv_path(
    X_train,
    y_train,
    alpha_sorted,
    penalty_name,
    l1_ratio,
    max_iter,
    tol,
    device,
    X_val=None,
    y_val=None,
    sample_weight=None,
    val_sample_weight=None,
    return_path=True,
):
    """Fit a squared-error sparse alpha path with centered data.

    This is used by CV for l1/elasticnet penalties. It solves all alphas in
    one fold using a single Gram matrix of the centered training data; the
    intercept is recovered as ``y_mean - X_mean @ coef``.

    Two strategies are used:

    * torch/cupy with validation data and ``return_path=False``: all alphas
      are solved simultaneously as columns of one coefficient matrix (each
      starting from zero). ``n_iter`` then repeats the shared iteration
      count.
    * otherwise: a warm-started FISTA path, one alpha at a time.

    Parameters
    ----------
    sample_weight : array-like, optional
        Must be uniform; otherwise a RuntimeWarning is emitted and ``None``
        is returned so the caller can fall back to the general CV path.
    val_sample_weight : array-like, optional
        Weights for validation MSE. If their sum is not positive, the
        unweighted mean is used instead of dividing by zero (matching
        ``_logistic_sparse_cv_path``).

    Returns
    -------
    dict or None
        ``"scores"`` (validation MSE per alpha, or ``None`` without
        validation data) and ``"n_iter"``; with ``return_path`` also
        ``"coef"`` (``(n_alphas, n_features)``) and ``"intercept"``.
    """
    if not _is_uniform_weight(sample_weight):
        warnings.warn(
            "_squared_error_sparse_cv_path: non-uniform sample_weight not supported, "
            "falling back to general CV path.",
            RuntimeWarning,
            stacklevel=2,
        )
        return None

    backend = _backend_name_for_cv_device(device)
    Xb = _to_backend_float64(X_train, backend)
    yb = _to_backend_float64(y_train, backend).reshape(-1)
    alphas = np.asarray(alpha_sorted, dtype=np.float64).ravel()
    n_samples, n_features = Xb.shape
    penalty_name = str(penalty_name).lower()
    is_enet = penalty_name in ("elasticnet", "en")

    from statgpu.backends._utils import _get_xp
    xp = _get_xp(backend)
    X_mean = xp.mean(Xb, axis=0)
    y_mean = xp.mean(yb)
    Xc = Xb - X_mean
    yc = yb - y_mean
    XtX = Xc.T @ Xc
    Xty = Xc.T @ yc
    coef = _zeros(n_features, backend, ref_tensor=Xb)

    eig_max = _max_eigval_power(XtX)
    L = max(eig_max / max(int(n_samples), 1), 1e-12)
    step = 1.0 / L
    conv_interval = _CONV_INTERVAL_CV_NUMPY if backend == "numpy" else _CONV_INTERVAL_CV_PATH

    if X_val is not None and y_val is not None:
        Xv = _to_backend_float64(X_val, backend)
        yv = _to_backend_float64(y_val, backend).reshape(-1)
        Xv_centered = Xv - X_mean
        swv = _to_backend_float64(val_sample_weight, backend).reshape(-1) if val_sample_weight is not None else None
        # A zero (or NaN) weight total would make every score NaN; fall back
        # to the unweighted mean instead.
        if swv is not None and not (_to_float_scalar(swv.sum()) > 0):
            swv = None
    else:
        Xv = yv = Xv_centered = swv = None

    if backend in ("torch", "cupy") and not return_path and Xv_centered is not None:
        from statgpu.backends._utils import xp_asarray
        from statgpu.backends._array_ops import _xp_zeros

        n_alpha = int(alphas.size)
        alpha_vec = xp_asarray(alphas, dtype=Xb.dtype, xp=xp, ref_arr=Xb).reshape(1, -1)
        # Per-column proximal parameters are iteration-invariant.
        if is_enet:
            thresh = alpha_vec * float(l1_ratio) * step
            denom = 1.0 + alpha_vec * (1.0 - float(l1_ratio)) * step
        else:
            thresh = alpha_vec * step
            denom = 1.0

        coef_mat = _xp_zeros((n_features, n_alpha), Xb.dtype, Xb)
        y_mat = _copy_arr(coef_mat)
        x_ty = Xty.reshape(-1, 1)
        t_k = 1.0
        last_iter = 0
        for iteration in range(int(max_iter)):
            coef_old = _copy_arr(coef_mat)
            grad = (XtX @ y_mat - x_ty) / n_samples
            w = y_mat - step * grad
            coef_mat = _soft_threshold(w, thresh) / denom

            beta, t_k = _nesterov_momentum(t_k, beta_cap=0.5)
            y_mat = coef_mat + beta * (coef_mat - coef_old)
            last_iter = iteration + 1

            if iteration < 20 or iteration % conv_interval == 0:
                delta = xp.sum(xp.abs(coef_mat - coef_old), axis=0)
                if _to_float_scalar(xp.all(delta < tol)):
                    break

        pred = Xv_centered @ coef_mat + y_mean
        sq_err = (yv.reshape(-1, 1) - pred) ** 2
        if swv is not None:
            weighted = swv.reshape(-1, 1) * sq_err
            col_sum = weighted.sum(dim=0) if backend == "torch" else weighted.sum(axis=0)
            scores_dev = col_sum / swv.sum()
        else:
            scores_dev = sq_err.mean(dim=0) if backend == "torch" else sq_err.mean(axis=0)
        return {
            "scores": np.asarray(_to_numpy(scores_dev), dtype=np.float64),
            "n_iter": np.full(n_alpha, int(last_iter), dtype=np.int64),
        }

    scores_dev = []  # accumulate on device, sync once at end
    coef_path = []
    intercept_path = []
    iters = []
    sw_total = xp.sum(swv) if swv is not None else None
    # Small problems (or NumPy) also check every iteration for the first 20.
    frequent_checks = backend == "numpy" or int(n_features) <= 128

    for alpha in alphas:
        if is_enet:
            thresh = float(alpha) * float(l1_ratio) * step
            denom = 1.0 + float(alpha) * (1.0 - float(l1_ratio)) * step
        else:
            thresh = float(alpha) * step
            denom = 1.0

        y_k = _copy_arr(coef)
        t_k = 1.0
        last_iter = 0
        for iteration in range(int(max_iter)):
            coef_old = _copy_arr(coef)
            grad = (XtX @ y_k - Xty) / n_samples
            w = y_k - step * grad
            coef = _soft_threshold(w, thresh) / denom

            beta, t_k = _nesterov_momentum(t_k, beta_cap=0.5)
            y_k = coef + beta * (coef - coef_old)
            last_iter = iteration + 1

            check_convergence = (frequent_checks and iteration < 20) or iteration % conv_interval == 0
            if check_convergence:
                delta = xp.sum(xp.abs(coef - coef_old))
                if _to_float_scalar(delta) < tol:
                    break

        if Xv_centered is not None:
            pred = Xv_centered @ coef + y_mean
            sq_err = (yv - pred) ** 2
            if swv is not None:
                mse = xp.sum(swv * sq_err) / sw_total
            else:
                mse = xp.mean(sq_err)
            scores_dev.append(mse)  # keep on device
        if return_path:
            intercept = y_mean - X_mean @ coef
            coef_path.append(np.asarray(_to_numpy(coef), dtype=np.float64).copy())
            intercept_path.append(_scalar_to_float(intercept))
        iters.append(last_iter)

    # Batch sync validation scores from device.
    scores = []
    if scores_dev:
        if backend == "torch":
            import torch
            scores = _to_numpy(torch.stack(scores_dev)).tolist()
        elif backend == "cupy":
            import cupy as cp
            scores = _to_numpy(cp.stack(scores_dev)).tolist()
        else:
            scores = [float(s) for s in scores_dev]

    out = {
        "scores": np.asarray(scores, dtype=np.float64) if scores else None,
        "n_iter": np.asarray(iters, dtype=np.int64),
    }
    if return_path:
        # np.vstack raises on an empty list, so handle an empty alpha grid.
        out["coef"] = (
            np.vstack(coef_path).astype(np.float64, copy=False)
            if coef_path
            else np.empty((0, int(n_features)), dtype=np.float64)
        )
        out["intercept"] = np.asarray(intercept_path, dtype=np.float64)
    return out
1226
+
1227
+
1228
+ # Intercept clipping bound: exp(15) ≈ 3.3M, prevents overflow in link
1229
+ # functions while allowing a wide range of intercept values.
1230
+ from statgpu.cross_validation._base import INTERCEPT_CLIP_BOUND as _INTERCEPT_CLIP_BOUND
1231
+
1232
+
1233
+ class _FeatureOnlySparsePenalty:
1234
+ """Wrap a sparse penalty so the final intercept coefficient is unpenalized."""
1235
+
1236
+ def __init__(self, base_penalty, n_features, backend):
1237
+ self.base_penalty = base_penalty
1238
+ self.n_features = int(n_features)
1239
+ self.backend = backend
1240
+
1241
+ @property
1242
+ def name(self):
1243
+ return getattr(self.base_penalty, "name", "")
1244
+
1245
+ @property
1246
+ def alpha(self):
1247
+ return float(getattr(self.base_penalty, "alpha", 0.0))
1248
+
1249
+ @property
1250
+ def l1_ratio(self):
1251
+ return float(getattr(self.base_penalty, "l1_ratio", 1.0))
1252
+
1253
+ def value(self, coef):
1254
+ return self.base_penalty.value(coef[: self.n_features])
1255
+
1256
+ def proximal(self, w, step, backend=None):
1257
+ from statgpu.backends._array_ops import _xp, _clip, _xp_zeros
1258
+ backend = backend or self.backend
1259
+ xp = _xp(w)
1260
+ w_feat = w[: self.n_features]
1261
+ result_feat = self.base_penalty.proximal(w_feat, step, backend=backend)
1262
+ result = _xp_zeros(w.shape, w.dtype, w)
1263
+ result[: self.n_features] = result_feat
1264
+ result[self.n_features] = _clip(w[self.n_features], -_INTERCEPT_CLIP_BOUND, _INTERCEPT_CLIP_BOUND)
1265
+ return result
1266
+
1267
+
1268
+ def _glm_sparse_cv_path(
1269
+ loss_name,
1270
+ X_train,
1271
+ y_train,
1272
+ alpha_sorted,
1273
+ penalty_name,
1274
+ l1_ratio,
1275
+ max_iter,
1276
+ tol,
1277
+ device,
1278
+ X_val=None,
1279
+ y_val=None,
1280
+ sample_weight=None,
1281
+ val_sample_weight=None,
1282
+ return_path=False,
1283
+ solver_name="fista",
1284
+ cv_mode=True,
1285
+ loss_kwargs=None,
1286
+ ):
1287
+ """Warm-started sparse GLM alpha path for CV.
1288
+
1289
+ The helper is intentionally private: it reuses the production loss,
1290
+ penalty, and FISTA solver while avoiding estimator reconstruction and
1291
+ repeated host/device conversions inside a fold.
1292
+
1293
+ When ``val_sample_weight`` is provided, validation loss is computed as
1294
+ a weighted mean instead of a simple mean.
1295
+ """
1296
+ loss_name = str(loss_name).lower()
1297
+ penalty_name = str(penalty_name).lower()
1298
+ # Allow any loss registered in the formula registry
1299
+ if loss_name not in _LOSS_RESIDUAL_FNS:
1300
+ return None
1301
+ if penalty_name not in ("l1", "elasticnet", "en"):
1302
+ return None
1303
+ if not _is_uniform_weight(sample_weight):
1304
+ warnings.warn(
1305
+ "_glm_sparse_cv_path: non-uniform sample_weight not supported, "
1306
+ "falling back to general CV path.",
1307
+ RuntimeWarning,
1308
+ stacklevel=2,
1309
+ )
1310
+ return None
1311
+
1312
+ from statgpu.solvers import fista_solver, fista_bb_solver
1313
+ from statgpu.linear_model.penalized._fit_mixin import _resolve_loss_name
1314
+ from statgpu.penalties import get_penalty
1315
+
1316
+ backend = _backend_name_for_cv_device(device)
1317
+ from statgpu.backends._utils import _get_xp
1318
+ xp = _get_xp(backend)
1319
+ Xb = _to_backend_float64(X_train, backend)
1320
+ yb = _to_backend_float64(y_train, backend).reshape(-1)
1321
+ alphas = np.asarray(alpha_sorted, dtype=np.float64).ravel()
1322
+ n_samples, n_features = Xb.shape
1323
+
1324
+ from statgpu.backends._utils import xp_ones as _xp_ones_fn
1325
+ _ones = _xp_ones_fn((n_samples, 1), dtype=Xb.dtype, xp=xp, ref_arr=Xb)
1326
+ X_work = xp.concatenate([Xb, _ones], axis=1)
1327
+
1328
+ if X_val is not None and y_val is not None:
1329
+ Xv = _to_backend_float64(X_val, backend)
1330
+ yv = _to_backend_float64(y_val, backend).reshape(-1)
1331
+ n_val = Xv.shape[0]
1332
+ _ones_v = _xp_ones_fn((n_val, 1), dtype=Xv.dtype, xp=xp, ref_arr=Xv)
1333
+ X_val_work = xp.concatenate([Xv, _ones_v], axis=1)
1334
+ else:
1335
+ X_val_work = yv = swv = None
1336
+
1337
+ if X_val is not None and y_val is not None and val_sample_weight is not None:
1338
+ swv = _to_backend_float64(val_sample_weight, backend).reshape(-1)
1339
+ else:
1340
+ swv = None
1341
+
1342
+ sw_fit = (
1343
+ _to_backend_float64(sample_weight, backend)
1344
+ if sample_weight is not None
1345
+ else None
1346
+ )
1347
+ loss_fn = _resolve_loss_name(loss_name, loss_kwargs=loss_kwargs)
1348
+ if penalty_name in ("elasticnet", "en"):
1349
+ base_penalty = get_penalty("elasticnet", alpha=float(alphas[0]), l1_ratio=float(l1_ratio))
1350
+ else:
1351
+ base_penalty = get_penalty("l1", alpha=float(alphas[0]))
1352
+ penalty = _FeatureOnlySparsePenalty(base_penalty, n_features, backend)
1353
+
1354
+ lipschitz_L = None
1355
+ if not getattr(loss_fn, "_lipschitz_at_init", False):
1356
+ try:
1357
+ zero_lip = _zeros(n_features + 1, backend, ref_tensor=X_work)
1358
+ lipschitz_L = float(_to_numpy(loss_fn.lipschitz(X_work, zero_lip, y=yb)))
1359
+ if not np.isfinite(lipschitz_L) or lipschitz_L <= 0.0:
1360
+ lipschitz_L = None
1361
+ except Exception:
1362
+ lipschitz_L = None
1363
+
1364
+ scores = []
1365
+ score_params_path = []
1366
+ coef_path = []
1367
+ intercept_path = []
1368
+ iters = []
1369
+ if backend == "torch":
1370
+ import torch
1371
+ y_mean = max(float(torch.mean(yb).item()), 1e-3)
1372
+ elif backend == "cupy":
1373
+ import cupy as cp
1374
+ y_mean = max(float(cp.mean(yb)), 1e-3)
1375
+ else:
1376
+ y_mean = max(float(np.mean(yb)), 1e-3)
1377
+ # Use the correct link-function inverse for intercept initialization:
1378
+ # logistic -> logit link: log(y_mean / (1 - y_mean))
1379
+ # poisson/gamma/tweedie/nb -> log link: log(y_mean)
1380
+ if loss_name == "logistic":
1381
+ y_mean_clipped = np.clip(y_mean, 1e-7, 1.0 - 1e-7)
1382
+ init_intercept = np.log(y_mean_clipped / (1.0 - y_mean_clipped))
1383
+ else:
1384
+ init_intercept = np.log(y_mean)
1385
+ init = _zeros(n_features + 1, backend, ref_tensor=X_work)
1386
+ init[-1] = init_intercept
1387
+ solver_name = str(solver_name).lower()
1388
+ solver_fn = fista_bb_solver if solver_name == "fista_bb" else fista_solver
1389
+ for alpha in alphas:
1390
+ base_penalty.alpha = float(alpha)
1391
+ solver_kwargs = {
1392
+ "max_iter": int(max_iter),
1393
+ "tol": tol,
1394
+ "init_coef": init,
1395
+ "sample_weight": sw_fit,
1396
+ }
1397
+ if lipschitz_L is not None:
1398
+ solver_kwargs["lipschitz_L"] = lipschitz_L
1399
+ if solver_fn is fista_solver or solver_name == "fista_bb":
1400
+ solver_kwargs["cv_mode"] = bool(cv_mode)
1401
+ params, n_iter = solver_fn(
1402
+ loss_fn,
1403
+ penalty,
1404
+ X_work,
1405
+ yb,
1406
+ **solver_kwargs,
1407
+ )
1408
+ init = params
1409
+ if X_val_work is not None:
1410
+ if backend == "torch":
1411
+ score_params_path.append(params.clone())
1412
+ elif backend == "cupy":
1413
+ score_params_path.append(params.copy())
1414
+ else:
1415
+ # NumPy path: compute validation loss
1416
+ if swv is not None:
1417
+ # Weighted loss path
1418
+ yv_np = np.asarray(_to_numpy(yv), dtype=np.float64).ravel()
1419
+ sw_np = np.asarray(_to_numpy(swv), dtype=np.float64).ravel()
1420
+ Xv_np = np.asarray(_to_numpy(Xv), dtype=np.float64) if Xv is not None else None
1421
+ params_np = np.asarray(_to_numpy(params), dtype=np.float64).ravel()
1422
+ val = _evaluate_loss_numpy(loss_name, loss_fn,
1423
+ Xv_np, yv_np,
1424
+ params_np[:n_features],
1425
+ float(params_np[n_features]),
1426
+ True, sample_weight=sw_np)
1427
+ else:
1428
+ val = float(loss_fn.value(X_val_work, yv, params))
1429
+ score_params_path.append(val)
1430
+ if return_path:
1431
+ params_np = np.asarray(_to_numpy(params), dtype=np.float64).ravel()
1432
+ coef_path.append(params_np[:n_features].copy())
1433
+ intercept_path.append(float(params_np[n_features]))
1434
+ iters.append(int(n_iter))
1435
+
1436
+ # GPU backends: compute per-sample validation loss via registry,
1437
+ # then aggregate across samples (weighted or unweighted).
1438
+ # The registry functions work with any shape (1D or 2D batched eta).
1439
+ if score_params_path:
1440
+ _loss_params = {}
1441
+ if loss_name == "negative_binomial":
1442
+ _loss_params["alpha"] = float(getattr(loss_fn, "alpha", _NB_ALPHA_DEFAULT))
1443
+ elif loss_name == "tweedie":
1444
+ _loss_params["power"] = float(getattr(loss_fn, "power", _TWEEDIE_POWER_DEFAULT))
1445
+
1446
+ if backend in ("torch", "cupy"):
1447
+ params_mat = xp.stack(score_params_path, axis=1)
1448
+ eta = X_val_work @ params_mat # (n_val, n_alphas)
1449
+ yy = yv.reshape(-1, 1)
1450
+ per_sample = _LOSS_VALLOSS_FNS[loss_name](eta, yy, **_loss_params)
1451
+ if swv is not None:
1452
+ sw_col = swv.reshape(-1, 1)
1453
+ sw_sum = _to_float_scalar(xp.sum(swv))
1454
+ if sw_sum > 0:
1455
+ scores_arr = xp.sum(sw_col * per_sample, axis=0) / sw_sum
1456
+ else:
1457
+ scores_arr = xp.mean(per_sample, axis=0)
1458
+ else:
1459
+ scores_arr = xp.mean(per_sample, axis=0)
1460
+ scores = _to_numpy(scores_arr).tolist()
1461
+ else:
1462
+ scores = [_scalar_to_float(s) for s in score_params_path]
1463
+
1464
+ out = {
1465
+ "scores": np.asarray(scores, dtype=np.float64) if scores else None,
1466
+ "n_iter": np.asarray(iters, dtype=np.int64),
1467
+ }
1468
+ if return_path:
1469
+ out["coef"] = np.vstack(coef_path).astype(np.float64, copy=False)
1470
+ out["intercept"] = np.asarray(intercept_path, dtype=np.float64)
1471
+ return out
1472
+
1473
+
1474
def _scad_mcp_cv_path(
    loss_name,
    X_train,
    y_train,
    alpha_sorted,
    penalty_name,
    l1_ratio,
    max_iter,
    tol,
    device,
    X_val=None,
    y_val=None,
    sample_weight=None,
    val_sample_weight=None,
    return_path=False,
    max_lla_per_step=3,
    lla_tol=1e-4,
    loss_kwargs=None,
):
    """Warm-started SCAD/MCP alpha path for CV.

    For each alpha (visited in the given order), run up to
    ``max_lla_per_step`` local linear approximation (LLA) rounds: compute
    LLA weights from the current feature coefficients (the intercept gets
    weight 0), then run an accelerated proximal-gradient loop with
    ``AdaptiveL1Penalty`` using those weights, warm-started from the
    previous solution. This avoids per-alpha ``model.fit()`` overhead.

    Parameters
    ----------
    loss_name : str
        Loss name (case-insensitive). ``"squared_error"`` is solved on
        centered data with a precomputed Gram matrix; other losses use
        ``loss_fn.gradient``.
    X_train, y_train : array-like
        Training design (no intercept column) and response.
    alpha_sorted : array-like
        Penalty strengths; warm starts chain along this order.
    penalty_name : str
        ``"scad"`` or ``"mcp"`` (case-insensitive); anything else returns None.
    l1_ratio : float
        Accepted for signature compatibility; not used by this path.
    max_iter : int
        Inner iteration budget, capped at ``_FISTA_MAX_ITER_CV``.
    tol : float
        Inner convergence tolerance, compared (via ``_device_gt``) with
        ``_abs_sum_dev`` of the coefficient change every 10 iterations.
    device : str or Device
        CV device, mapped to a backend with ``_backend_name_for_cv_device``.
    X_val, y_val : array-like, optional
        Validation data; when both are given one validation loss per alpha
        is returned in ``"scores"``.
    sample_weight : array-like, optional
        Training weights. Only uniform weights are supported; otherwise a
        RuntimeWarning is emitted and None is returned.
    val_sample_weight : array-like, optional
        Validation weights for a weighted-mean validation loss.
    return_path : bool
        Also return ``"coef"`` (n_alphas, n_features) and ``"intercept"``.
    max_lla_per_step : int
        Maximum LLA rounds per alpha (0 or less leaves coefficients unchanged).
    lla_tol : float
        LLA rounds stop once ``_abs_sum_dev`` of the round's coefficient
        change falls below this value.
    loss_kwargs : dict, optional
        Forwarded to ``_resolve_loss_name``.

    Returns
    -------
    dict or None
        ``{"scores": ndarray or None, "n_iter": int64 ndarray}`` plus
        ``"coef"``/``"intercept"`` when ``return_path``. ``n_iter[k]`` is the
        number of inner iterations of the last LLA round for alpha ``k``.
        None when the penalty is unsupported or weights are non-uniform.
    """
    loss_name = str(loss_name).lower()
    penalty_name = str(penalty_name).lower()
    if penalty_name not in ("scad", "mcp"):
        return None
    if not _is_uniform_weight(sample_weight):
        warnings.warn(
            "_scad_mcp_cv_path: non-uniform sample_weight not supported, "
            "falling back to general CV path.",
            RuntimeWarning,
            stacklevel=2,
        )
        return None

    from statgpu.backends._utils import _get_xp
    from statgpu.backends._utils import xp_ones as _xp_ones_fn
    from statgpu.linear_model.penalized._fit_mixin import _resolve_loss_name
    from statgpu.penalties import MCPPenalty, SCADPenalty
    from statgpu.penalties._adaptive_l1 import AdaptiveL1Penalty

    backend = _backend_name_for_cv_device(device)
    xp = _get_xp(backend)
    Xb = _to_backend_float64(X_train, backend)
    yb = _to_backend_float64(y_train, backend).reshape(-1)
    alphas = np.asarray(alpha_sorted, dtype=np.float64).ravel()
    n_samples, n_features = Xb.shape

    # Training design with a trailing ones column: params[n_features] is the intercept.
    ones_col = _xp_ones_fn((n_samples, 1), dtype=Xb.dtype, xp=xp, ref_arr=Xb)
    X_work = xp.concatenate([Xb, ones_col], axis=1)

    # Validation design in the same layout (only when both X_val and y_val are given).
    X_val_work = yv = swv = None
    if X_val is not None and y_val is not None:
        Xv = _to_backend_float64(X_val, backend)
        yv = _to_backend_float64(y_val, backend).reshape(-1)
        n_val = Xv.shape[0]
        if backend == "torch":
            ones_v = xp.ones((n_val, 1), dtype=Xv.dtype, device=Xv.device)
            X_val_work = xp.concatenate([Xv, ones_v], axis=1)
        elif backend == "cupy":
            ones_v = xp.ones((n_val, 1), dtype=Xv.dtype)
            X_val_work = xp.concatenate([Xv, ones_v], axis=1)
        else:
            ones_v = np.ones((n_val, 1), dtype=Xv.dtype)
            X_val_work = np.concatenate([Xv, ones_v], axis=1)
        if val_sample_weight is not None:
            swv = _to_backend_float64(val_sample_weight, backend).reshape(-1)

    loss_fn = _resolve_loss_name(loss_name, loss_kwargs=loss_kwargs)
    penalty_cls = SCADPenalty if penalty_name == "scad" else MCPPenalty
    scad_penalty = penalty_cls(alpha=float(alphas[0]))

    # Step size: computed once and reused for every alpha and LLA round.
    is_quadratic = loss_name == "squared_error"
    X_mean = y_mean = XtX = Xty = None
    if is_quadratic:
        # Centered data makes the intercept drop out of the gradient; it is
        # recovered in closed form after each alpha.
        X_mean = xp.mean(X_work[:, :n_features], axis=0)
        y_mean = xp.mean(yb)
        Xc = X_work[:, :n_features] - X_mean
        yc = yb - y_mean
        XtX = Xc.T @ Xc / n_samples
        Xty = Xc.T @ yc / n_samples
        lipschitz = max(_max_eigval_power(XtX) * 1.01, 1.0)  # small safety factor
    else:
        zero_params = _zeros(n_features + 1, backend, ref_tensor=Xb)
        lipschitz = float(_to_numpy(loss_fn.lipschitz(X_work, zero_params, y=yb)))
        safety = getattr(loss_fn, '_lipschitz_safety', 1.0)
        if safety > 1.0:
            lipschitz *= safety
        # Y-scaling for exp-link families, skipped when the loss's own
        # Lipschitz estimate already accounts for y.
        if (getattr(loss_fn, 'name', '') not in ('squared_error',)
                and not getattr(loss_fn, '_lipschitz_uses_y', False)):
            y_abs = np.abs(_to_numpy(yb))
            y_scale = min(10.0, max(1.0, np.sqrt(float(np.mean(y_abs)) * float(np.max(y_abs)))))
            if y_scale > 1.0:
                lipschitz *= y_scale
        lipschitz = max(lipschitz, 1.0)
    step = 1.0 / lipschitz

    # Loss-specific parameters for the weighted validation loss.
    _loss_params = {}
    if loss_name == "negative_binomial":
        _loss_params["alpha"] = float(getattr(loss_fn, "alpha", _NB_ALPHA_DEFAULT))
    elif loss_name == "tweedie":
        _loss_params["power"] = float(getattr(loss_fn, "power", _TWEEDIE_POWER_DEFAULT))

    # Loop invariants: iteration cap (kept small for CV) and the zero entry
    # appended for the unpenalized / gradient-free intercept slot.
    inner_max_iter = min(int(max_iter), _FISTA_MAX_ITER_CV)
    zero_tail = _zeros(1, backend, ref_tensor=Xb)

    coef = _zeros(n_features + 1, backend, ref_tensor=Xb)
    inner_pen = AdaptiveL1Penalty(alpha=1.0)  # weights are swapped in per LLA round

    scores = []
    coef_path = []
    intercept_path = []
    iters = []

    for alpha in alphas:
        scad_penalty.alpha = float(alpha)
        # Bound before the LLA loop so n_iter is defined even if no round runs.
        iteration = -1

        for _lla_round in range(max_lla_per_step):
            lla_w_feat = scad_penalty.lla_weights(coef[:n_features])
            inner_pen._weights = xp.concatenate([lla_w_feat, zero_tail])
            coef_before_lla = _copy_arr(coef)

            iteration = -1  # stays -1 when inner_max_iter <= 0
            y_k = _copy_arr(coef)
            t_k = 1.0
            for iteration in range(inner_max_iter):
                coef_old = _copy_arr(coef)
                if is_quadratic:
                    grad = xp.concatenate([XtX @ y_k[:n_features] - Xty, zero_tail])
                else:
                    grad = loss_fn.gradient(X_work, yb, y_k)
                coef = inner_pen.proximal(y_k - step * grad, step, backend=backend)
                beta_mom, t_k = _nesterov_momentum(t_k, beta_cap=0.5)
                y_k = coef + beta_mom * (coef - coef_old)
                # Device-side convergence check, every 10 iterations for CV.
                if iteration > 0 and iteration % 10 == 0:
                    if _device_gt(tol, _abs_sum_dev(coef - coef_old)):
                        break

            if _device_gt(lla_tol, _abs_sum_dev(coef - coef_before_lla)):
                break

        coef_feat = coef[:n_features]
        intercept = None
        if is_quadratic:
            # The centered solve leaves coef[n_features] untouched; store the
            # closed-form intercept so validation sees the full model.
            intercept = float(y_mean - float(_to_numpy(xp.dot(X_mean, coef_feat))))
            coef[n_features] = intercept

        if X_val_work is not None:
            if swv is not None:
                eta_v = X_val_work @ coef
                if is_quadratic:
                    per_sample = (yv - eta_v) ** 2
                else:
                    per_sample = _LOSS_VALLOSS_FNS[loss_name](eta_v, yv, **_loss_params)
                sw_sum = _to_float_scalar(xp.sum(swv))
                val_loss = _to_float_scalar(xp.sum(swv * per_sample)) / max(sw_sum, 1e-15)
            else:
                val_loss = loss_fn.value(X_val_work, yv, coef)
            scores.append(float(val_loss))

        if return_path:
            # Host copies are only needed when the path is requested.
            if backend == "torch":
                coef_path.append(coef_feat.detach().cpu().numpy())
            elif backend == "cupy":
                coef_path.append(coef_feat.get())
            else:
                coef_path.append(coef_feat.copy())
            if intercept is None:
                if backend == "torch":
                    intercept = float(coef[n_features].item())
                elif backend == "cupy":
                    intercept = float(coef[n_features].get())
                else:
                    intercept = float(coef[n_features])
            intercept_path.append(intercept)
        iters.append(iteration + 1)

    out = {
        "scores": np.asarray(scores, dtype=np.float64) if scores else None,
        "n_iter": np.asarray(iters, dtype=np.int64),
    }
    if return_path:
        out["coef"] = np.vstack(coef_path).astype(np.float64, copy=False)
        out["intercept"] = np.asarray(intercept_path, dtype=np.float64)
    return out
1750
+
1751
+
1752
# ---------------------------------------------------------------------------
# Data-driven GPU device selection thresholds for CV
# ---------------------------------------------------------------------------
# Consulted by PenalizedGLM_CV._effective_cv_device when device is "auto".
# Every rule is a 6-tuple:
#
#   (loss, penalties, min_nx, min_features, reason, or_min_features)
#
#   loss             loss name, or None to match every loss except
#                    "squared_error"
#   penalties        penalty names the rule applies to
#   min_nx           minimum n_samples * n_features
#   min_features     minimum n_features (0 disables the feature check)
#   reason           explanation recorded as the auto-selection reason
#   or_min_features  when > 0, the rule also fires if
#                    n_features >= or_min_features AND
#                    n_samples * n_features >= 1_000_000
#
# A firing rule selects "torch" when torch CUDA is available, else "cpu".
_CV_SPARSE_PENALTIES = ("l1", "elasticnet", "en")

_CV_DEVICE_THRESHOLDS = [
    (
        "squared_error", _CV_SPARSE_PENALTIES, 1_000_000, 256,
        "medium squared-error sparse CV benefits from batched torch alpha path",
        0,
    ),
    (
        None, ("scad", "mcp"), 1_000_000, 0,
        "large GLM SCAD/MCP CV benefits from torch async FISTA",
        0,
    ),
    (
        "logistic", _CV_SPARSE_PENALTIES, 1_000_000, 500,
        "high-dimensional logistic sparse CV benefits from torch",
        0,
    ),
    (
        "logistic", _CV_SPARSE_PENALTIES, 500_000, 100,
        "medium logistic sparse CV benefits from fold-batched torch path",
        0,
    ),
    (
        "poisson", _CV_SPARSE_PENALTIES, 1_000_000, 500,
        "high-dimensional poisson sparse CV benefits from torch",
        0,
    ),
    (
        "gamma", _CV_SPARSE_PENALTIES, 2_000_000, 500,
        "large high-dimensional gamma sparse CV benefits from torch",
        0,
    ),
    (
        "inverse_gaussian", _CV_SPARSE_PENALTIES, 2_000_000, 500,
        "large high-dimensional inverse-gaussian sparse CV benefits from torch",
        0,
    ),
    (
        "tweedie", _CV_SPARSE_PENALTIES, 300_000, 0,
        "medium tweedie sparse CV is faster on torch",
        0,
    ),
]

# Losses that always run CV on the CPU, regardless of problem size,
# mapped to the reason recorded for that decision.
_CV_DEVICE_ALWAYS_CPU = dict(
    negative_binomial="negative-binomial CV is faster on CPU for current benchmarked sizes",
)
1786
+
1787
+
1788
+ class PenalizedGLM_CV(CVEstimatorBase):
1789
+ """Cross-validated penalized GLM supporting all loss + penalty combinations."""
1790
+
1791
def __init__(
    self,
    loss: str = 'squared_error',
    penalty: str = 'l2',
    alpha_grid=None,
    n_alphas: int = 100,
    l1_ratio: float = 0.5,
    cv: int = 5,
    cv_splits=None,
    random_state: Optional[int] = 0,
    device: Union[str, Device] = Device.AUTO,
    max_iter: int = 1000,
    tol: float = 1e-4,
    solver: str = 'auto',
    cv_strategy: str = "strict",
    acknowledge_approx: bool = False,
    refine_top_k: int = 3,
    loss_kwargs: Optional[dict] = None,
):
    """Configure the cross-validated penalized GLM.

    Parameters
    ----------
    loss, penalty : str
        Loss and penalty names used for the CV fits and the refit.
    alpha_grid : array-like, optional
        User-supplied alpha grid, stored as ``_alpha_grid_input``.
    n_alphas : int
        Grid size used by ``_generate_alpha_grid``.
    l1_ratio : float
        Elastic-net mixing parameter.
    cv, random_state, device
        Forwarded to the CV base-class constructor.
    cv_splits : optional
        Stored as ``cv_splits``.
    max_iter, tol, solver
        Solver settings for the CV fits and the refit.
    cv_strategy : {'strict', 'two_stage'}
        Case-insensitive; stored lowercased.
    acknowledge_approx : bool
        Stored as a bool.
    refine_top_k : int
        Must be >= 1; stored as an int.
    loss_kwargs : dict, optional
        Extra loss parameters; stored as ``_loss_kwargs`` (``{}`` when falsy).

    Raises
    ------
    ValueError
        If ``cv_strategy`` is not 'strict'/'two_stage' or ``refine_top_k < 1``.
    """
    super().__init__(cv=cv, random_state=random_state, device=device)
    self.cv_splits = cv_splits

    strategy = str(cv_strategy).lower()
    if strategy not in ("strict", "two_stage"):
        raise ValueError(
            "cv_strategy must be either 'strict' or 'two_stage', "
            f"got {strategy!r}."
        )
    top_k = int(refine_top_k)
    if top_k < 1:
        raise ValueError("refine_top_k must be a positive integer")

    # Estimator configuration.
    self.loss = loss
    self._loss_kwargs = loss_kwargs or {}
    self.penalty = penalty
    self._alpha_grid_input = alpha_grid
    self.n_alphas = n_alphas
    self.l1_ratio = l1_ratio
    self.max_iter = max_iter
    self.tol = tol
    self.solver = solver
    self.cv_strategy = strategy
    self.acknowledge_approx = bool(acknowledge_approx)
    self.refine_top_k = top_k

    # Fitted / diagnostic attributes start out unset.
    for attr_name in (
        "alpha_",
        "alpha_grid_",
        "cv_strategy_",
        "cv_selected_device_",
        "_cv_auto_reason_",
    ):
        setattr(self, attr_name, None)
1838
+
1839
def _solver_for_cv(self, cv_device=None, X=None):
    """Return the solver name the CV loop should use.

    A non-``'auto'`` ``self.solver`` is returned lowercased as-is. For
    ``'auto'`` the choice is delegated to ``_preferred_penalized_glm_solver``
    with ``cv_mode=True``, the backend implied by *cv_device* (``self.device``
    when *cv_device* is None) and, if *X* is given, the problem size
    ``n_samples * n_features``.
    """
    requested = str(self.solver).lower()
    if requested != "auto":
        return requested

    from statgpu.linear_model.penalized._fit_mixin import _preferred_penalized_glm_solver

    target_device = cv_device if cv_device is not None else self.device
    backend_name = _backend_name_for_cv_device(target_device)
    problem_size = None
    if X is not None:
        problem_size = int(X.shape[0]) * int(X.shape[1])
    return _preferred_penalized_glm_solver(
        self.loss,
        self.penalty,
        backend_name=backend_name,
        l1_ratio=self.l1_ratio,
        cv_mode=True,
        problem_size=problem_size,
    )
1856
+
1857
def _effective_cv_device(self, X, penalty_name, n_alphas):
    """Resolve the device for CV-level work; explicit devices are untouched.

    For ``device='auto'`` the decision is made in this order, with
    ``nx = n_samples * n_features``:

    1. ``nx < _SMALL_PROBLEM_THRESHOLD`` -> ``"cpu"``.
    2. Loss in ``_CV_DEVICE_ALWAYS_CPU`` with an l2/l1/elasticnet penalty
       -> ``"cpu"``.
    3. ``_CV_DEVICE_THRESHOLDS``: every rule whose loss and penalty match is
       tried in list order, so a looser later rule for the same
       (loss, penalty) is reachable. A firing rule selects ``"torch"``
       when ``_torch_cuda_available()``. If rules matched but none
       selected torch, ``"cpu"`` is returned with the first matched rule's
       reason.
    4. Effective-work break-even (``nx * cv * n_alphas``, times 20 for
       non-squared-error SCAD/MCP): below ``_GPU_BREAK_EVEN_THRESHOLD``
       -> ``"cpu"``; otherwise torch with CUDA, then CuPy with a usable
       CUDA device, else ``"cpu"``.

    Side effects: sets ``self.cv_selected_device_`` and ``self._cv_auto_reason_``.

    Returns
    -------
    str or Device
        ``self.device`` when not ``auto``; otherwise ``"cpu"``, ``"torch"``
        or ``"cupy"``.
    """
    self.cv_selected_device_ = self.device
    self._cv_auto_reason_ = None
    if _device_to_name(self.device) != "auto":
        return self.device

    n_samples, n_features = X.shape
    n_features = int(n_features)
    penalty_name = str(penalty_name).lower()
    loss_name = str(self.loss).lower()
    nx = int(n_samples) * n_features

    def _choose(device_name, reason):
        # Record the decision and its explanation, then return it.
        self.cv_selected_device_ = device_name
        self._cv_auto_reason_ = reason
        return device_name

    # Small problems: always CPU.
    if nx < _SMALL_PROBLEM_THRESHOLD:
        return _choose("cpu", "small CV problem is faster on CPU")

    # Always-CPU losses.
    if loss_name in _CV_DEVICE_ALWAYS_CPU and penalty_name in ("l2", "l1", "elasticnet", "en"):
        return _choose("cpu", _CV_DEVICE_ALWAYS_CPU[loss_name])

    # Data-driven thresholds. Returning on the first matching rule made
    # later rules for the same loss/penalty (e.g. the medium-size logistic
    # rule) unreachable, so every matching rule is tried.
    first_reason = None
    for rule_loss, rule_penalties, min_nx, min_features, reason, or_min_feat in _CV_DEVICE_THRESHOLDS:
        loss_match = (rule_loss is None and loss_name != "squared_error") or rule_loss == loss_name
        if not (loss_match and penalty_name in rule_penalties):
            continue
        if first_reason is None:
            first_reason = reason
        # Primary condition: nx >= min_nx AND n_features >= min_features.
        cond = nx >= min_nx and n_features >= min_features
        # OR condition: n_features >= or_min_feat AND nx >= 1_000_000.
        if not cond and or_min_feat > 0:
            cond = n_features >= or_min_feat and nx >= 1_000_000
        if cond:
            if _torch_cuda_available():
                return _choose("torch", reason)
            # No CUDA: no later rule can select torch either.
            break
    if first_reason is not None:
        return _choose(
            "cpu",
            first_reason.replace("benefits from", "is faster on CPU below break-even for"),
        )

    # Fallback: large effective work -> GPU.
    continuation_factor = 20 if loss_name != "squared_error" and penalty_name in ("scad", "mcp") else 1
    effective_work = nx * int(self.cv) * int(n_alphas) * continuation_factor
    if effective_work < _GPU_BREAK_EVEN_THRESHOLD:
        return _choose("cpu", "CV effective work is below GPU break-even")

    try:
        import torch
        if torch.cuda.is_available():
            return _choose("torch", "GPU selected for large CV effective work")
    except ImportError:
        pass

    try:
        import cupy
    except ImportError:
        cupy = None
    if cupy is not None:
        # A successful import does not guarantee a usable CUDA device.
        try:
            cupy_has_gpu = bool(cupy.cuda.is_available())
        except Exception:
            cupy_has_gpu = False
        if cupy_has_gpu:
            return _choose("cupy", "GPU selected for large CV effective work")

    return _choose("cpu", "No GPU available, falling back to CPU")
1925
+
1926
def _generate_alpha_grid(self, X, y):
    """Auto-generate a decreasing geometric alpha grid from the data.

    ``alpha_max`` is ``max |X^T (y - mu_null)| / n``, the largest gradient
    coordinate at the null model:

    * squared_error and logistic: ``mu_null = mean(y)``;
    * other losses: predictions of a short (``max_iter=5``) CPU fit with
      ``penalty='l2', alpha=0``. If that fails, a RuntimeWarning is emitted
      and 1.0 is used.

    For elastic net (``'elasticnet'`` or its ``'en'`` alias), the L1 part is
    ``alpha * l1_ratio``, so ``alpha_max`` is divided by ``l1_ratio``
    (floored at 1e-10). A non-finite or non-positive ``alpha_max`` is
    replaced by 1.0 with a RuntimeWarning. Loss and penalty names are
    compared case-insensitively.

    Returns
    -------
    numpy.ndarray
        ``self.n_alphas`` values from ``alpha_max`` down to
        ``max(alpha_max * 1e-4, 1e-12)``.
    """
    loss_name = str(self.loss).lower()
    penalty_name = str(self.penalty).lower()

    X_np = _to_numpy(X).astype(np.float64)
    y_np = _to_numpy(y).astype(np.float64).ravel()
    n = X_np.shape[0]

    if loss_name in ("squared_error", "logistic"):
        # The null (intercept-only) model predicts mean(y) for both losses,
        # so the null-model gradient is X^T (y - mean(y)) / n.
        alpha_max = float(np.max(np.abs(X_np.T @ (y_np - np.mean(y_np))))) / n
    else:
        from statgpu.linear_model.penalized._base import PenalizedGeneralizedLinearModel

        try:
            model = PenalizedGeneralizedLinearModel(
                loss=self.loss, penalty='l2', alpha=0.0,
                device='cpu', compute_inference=False, max_iter=5,
                loss_kwargs=getattr(self, '_loss_kwargs', None),
            )
            model.fit(X_np, y_np)
            grad = X_np.T @ (y_np - _to_numpy(model.predict(X_np))) / n
            alpha_max = float(np.max(np.abs(grad)))
        except Exception as e:
            warnings.warn(
                f"Alpha grid estimation failed ({e}), using alpha_max=1.0",
                RuntimeWarning,
                stacklevel=2,
            )
            alpha_max = 1.0

    # Elastic net: the L1 threshold is alpha * l1_ratio. The 'en' alias is
    # accepted everywhere else in this module, so it must be scaled too.
    if penalty_name in ("elasticnet", "en") and hasattr(self, 'l1_ratio'):
        _l1r = max(float(self.l1_ratio), 1e-10)
        alpha_max = alpha_max / _l1r

    # NaN slips past a plain "<= 0" check, so test finiteness explicitly.
    if not np.isfinite(alpha_max) or alpha_max <= 0:
        warnings.warn(
            f"Alpha grid estimation returned {alpha_max}, using alpha_max=1.0",
            RuntimeWarning,
            stacklevel=2,
        )
        alpha_max = 1.0

    grid = np.geomspace(alpha_max, max(alpha_max * 1e-4, 1e-12), self.n_alphas)
    return grid
1975
+
1976
def _solve_ridge_fold_batch(self, X_train, y_train, X_val, y_val, alphas):
    """Score all alphas on one Ridge CV fold via ``_ridge_eig_batch``.

    Inputs are converted to float64 NumPy arrays (responses and alphas are
    flattened to 1-D). The result of ``_ridge_eig_batch`` is returned
    unchanged.
    """
    def as_matrix(arr):
        return _to_numpy(arr).astype(np.float64)

    def as_vector(arr):
        return as_matrix(arr).ravel()

    return _ridge_eig_batch(
        as_matrix(X_train),
        as_vector(y_train),
        as_matrix(X_val),
        as_vector(y_val),
        as_vector(alphas),
    )
1984
+
1985
def _evaluate_single(self, model, X_val, y_val, loss_fn=None, X_val_np=None, y_val_np=None, sample_weight=None):
    """Evaluate a fitted model on validation data and return the validation loss.

    Parameters
    ----------
    model : fitted estimator
        Must expose ``coef_``, ``intercept_`` and ``fit_intercept`` (and
        ``predict`` for the last-resort fallback).
    X_val, y_val : array-like
        Validation data; only converted when the NumPy caches are None.
    loss_fn : optional
        Pre-resolved loss object. When None it is resolved from
        ``self.loss`` together with ``self._loss_kwargs``, so parameterized
        losses (e.g. Tweedie power, negative-binomial alpha) are scored with
        the parameters used for fitting.
    X_val_np, y_val_np : numpy.ndarray, optional
        Pre-converted float64 validation data (avoids repeated device-to-host copies).
    sample_weight : array-like, optional
        Per-sample weights; used only by the primary
        ``_evaluate_loss_numpy`` path. The fallbacks are unweighted.

    Returns
    -------
    The validation loss. Fallback paths return a Python float; if
    ``loss_fn.value`` also fails, the MSE is returned with a RuntimeWarning.
    """
    if loss_fn is None:
        from statgpu.linear_model.penalized._fit_mixin import _resolve_loss_name

        # ``or None`` keeps the call identical to the kwarg-free case when
        # no loss parameters were configured.
        loss_fn = _resolve_loss_name(
            self.loss, loss_kwargs=getattr(self, '_loss_kwargs', None) or None
        )
    if X_val_np is None:
        X_val_np = _to_numpy(X_val).astype(np.float64)
    if y_val_np is None:
        y_val_np = _to_numpy(y_val).astype(np.float64).ravel()
    n_val = X_val_np.shape[0]

    try:
        coef_np = _to_numpy(model.coef_).ravel()
        val_loss = _evaluate_loss_numpy(
            self.loss,
            loss_fn,
            X_val_np,
            y_val_np,
            coef_np,
            float(model.intercept_),
            model.fit_intercept,
            sample_weight=sample_weight,
        )
    except Exception:
        # Fallback: use loss_fn.value() for the correct loss, not raw MSE.
        try:
            if model.fit_intercept:
                X_design = np.column_stack([np.ones(n_val), X_val_np])
                coef_full = np.concatenate([[float(model.intercept_)], _to_numpy(model.coef_).ravel()])
            else:
                X_design = X_val_np
                coef_full = _to_numpy(model.coef_).ravel()
            val_loss = float(loss_fn.value(X_design, y_val_np, coef_full))
        except Exception:
            y_pred_np = _to_numpy(model.predict(X_val_np)).ravel()
            val_loss = float(np.mean((y_val_np - y_pred_np) ** 2))
            warnings.warn(
                f"_evaluate_single: loss evaluation failed for '{self.loss}', "
                "falling back to MSE. CV scores may be inaccurate for non-Gaussian losses.",
                RuntimeWarning,
                stacklevel=2,
            )

    return val_loss
2036
+
2037
@staticmethod
def _populate_refit_model(model, coef, intercept, X, device, n_iter=None):
    """Copy path-solver results onto *model*, mark it fitted and return it.

    Sets the following attributes on *model*:

    * ``coef_``: a float64 array.
    * ``intercept_``: a float.
    * ``n_iter_``: only when *n_iter* is given.
    * ``_params``: the intercept followed by the coefficients.
    * ``_nobs``: the number of rows of *X*.
    * ``_df_resid``: rows minus columns of *X*, minus one more when
      ``fit_intercept`` is true (treated as True if the attribute is
      missing).
    * ``_selected_backend_name``: derived from *device*.
    * ``_fitted``: set to True.
    """
    coef_arr = np.asarray(coef, dtype=np.float64)
    model.coef_ = coef_arr
    model.intercept_ = b0 = float(intercept)
    if n_iter is not None:
        model.n_iter_ = int(n_iter)
    model._params = np.concatenate([[b0], coef_arr])

    n_obs = int(X.shape[0])
    model._nobs = n_obs
    intercept_slots = 1 if bool(getattr(model, 'fit_intercept', True)) else 0
    model._df_resid = n_obs - (int(X.shape[1]) + intercept_slots)

    model._selected_backend_name = _backend_name_for_cv_device(device)
    model._fitted = True
    return model
2051
+
2052
def _refit_best(self, X, y, best_alpha, sample_weight=None):
    """Refit on the full data with *best_alpha* and return the fitted model.

    Strategies, tried in order:

    * squared_error + l2: closed-form ``_ridge_eig_single`` (weighted when
      *sample_weight* is given), matching the eigendecomposition used by the
      CV fast path.
    * logistic or squared_error with l1/elasticnet, and configurations
      accepted by ``_uses_glm_sparse_path``: the specialized path solvers
      run on a one-element alpha grid. A path result that is None or has
      no ``"coef"`` is skipped.
    * otherwise: ``PenalizedGeneralizedLinearModel.fit``.

    Loss and penalty names are compared case-insensitively for routing.
    When ``self.device`` is ``'auto'``, the refit device is
    ``_cv_selected_device_`` if set, else ``cv_selected_device_`` (the
    attribute ``_effective_cv_device`` writes), else ``self.device``.
    """
    from statgpu.linear_model.penalized._base import PenalizedGeneralizedLinearModel

    loss_name = str(self.loss).lower()
    penalty_name = str(self.penalty).lower()
    loss_kwargs = getattr(self, '_loss_kwargs', None)

    # Resolve the refit device (used by the Ridge and general paths). The CV
    # device resolver stores its choice in ``cv_selected_device_``; looking
    # only at ``_cv_selected_device_`` made auto refits ignore it.
    refit_device = self.device
    if _device_to_name(self.device) == "auto":
        refit_device = (
            getattr(self, "_cv_selected_device_", None)
            or getattr(self, "cv_selected_device_", None)
            or self.device
        )

    # Ridge: eigendecomposition matches the CV path exactly; weighted Ridge
    # uses a weighted eigensolve at the same O(p^3) cost.
    if loss_name == 'squared_error' and penalty_name == 'l2':
        X_np = _to_numpy(X).astype(np.float64)
        y_np = _to_numpy(y).astype(np.float64).ravel()
        sw_np = None if sample_weight is None else _to_numpy(sample_weight).astype(np.float64).ravel()
        coef, intercept = _ridge_eig_single(X_np, y_np, best_alpha, sample_weight=sw_np)
        model = PenalizedGeneralizedLinearModel(
            loss='squared_error', penalty='l2', alpha=best_alpha,
            device=refit_device, compute_inference=False,
            max_iter=self.max_iter, tol=self.tol,
            loss_kwargs=loss_kwargs,
        )
        return self._populate_refit_model(model, coef, intercept, X, refit_device)

    alpha_arr = np.asarray([best_alpha], dtype=np.float64)
    sparse_penalty = penalty_name in ("l1", "elasticnet", "en")
    cv_solver = self._solver_for_cv(refit_device, X=X)

    # Specialized refit paths (each returns a path dict or None), run lazily.
    refit_paths = []
    if loss_name == "logistic" and sparse_penalty:
        refit_paths.append(lambda: _logistic_sparse_cv_path(
            X, y, alpha_arr, penalty_name, self.l1_ratio,
            _logistic_sparse_effective_max_iter(self.max_iter, refit_device, penalty_name, refit=True),
            self.tol, refit_device, sample_weight=sample_weight,
        ))
    if loss_name == "squared_error" and sparse_penalty:
        refit_paths.append(lambda: _squared_error_sparse_cv_path(
            X, y, alpha_arr, penalty_name, self.l1_ratio,
            self.max_iter, self.tol, refit_device, sample_weight=sample_weight,
        ))
    if self._uses_glm_sparse_path(penalty_name, cv_solver):
        refit_paths.append(lambda: _glm_sparse_cv_path(
            self.loss, X, y, alpha_arr, penalty_name, self.l1_ratio,
            self.max_iter, self.tol, refit_device,
            return_path=True, solver_name=cv_solver, cv_mode=False,
            sample_weight=sample_weight,
        ))

    for run_path in refit_paths:
        path = run_path()
        if path is None or path.get("coef") is None:
            continue
        model = PenalizedGeneralizedLinearModel(
            loss=self.loss, penalty=self.penalty, alpha=best_alpha,
            l1_ratio=self.l1_ratio, device=refit_device,
            compute_inference=False, max_iter=self.max_iter,
            tol=self.tol, solver=cv_solver,
            loss_kwargs=loss_kwargs,
        )
        return self._populate_refit_model(
            model, path["coef"][-1], path["intercept"][-1],
            X, refit_device, n_iter=path["n_iter"][-1],
        )

    # General fallback: model.fit(). Inference was only ever enabled for
    # squared_error + l2, which already returned via the Ridge path above.
    model = PenalizedGeneralizedLinearModel(
        loss=self.loss, penalty=self.penalty, alpha=best_alpha,
        l1_ratio=self.l1_ratio, device=refit_device,
        compute_inference=False, max_iter=self.max_iter,
        tol=self.tol, solver=cv_solver,
        loss_kwargs=loss_kwargs,
    )
    model.fit(X, y, sample_weight=sample_weight)
    return model
2131
+
2132
def _uses_glm_sparse_path(self, penalty_name, cv_solver):
    """Return True when ``_glm_sparse_cv_path`` handles this configuration.

    All three conditions must hold:

    * an L1-type penalty (``'l1'``, ``'elasticnet'``, ``'en'``);
    * a FISTA-family solver (``'auto'``, ``'fista'``, ``'fista_bb'``);
    * a supported loss, ``self.loss``: poisson, gamma, inverse_gaussian or
      tweedie, or negative_binomial only with ``'fista_bb'``.

    Penalty and solver are compared case-insensitively; ``self.loss`` is
    compared as given.
    """
    penalty = str(penalty_name).lower()
    solver = str(cv_solver).lower()
    loss = self.loss

    if penalty not in ("l1", "elasticnet", "en"):
        return False
    if solver not in ("auto", "fista", "fista_bb"):
        return False
    if loss == "negative_binomial":
        return solver == "fista_bb"
    return loss in ("poisson", "gamma", "inverse_gaussian", "tweedie")
2144
+
2145
def _best_index_from_scores(self, mean_scores, alpha_grid, cv_solver):
    """Return the index of the best (lowest) mean CV score.

    Delegates to ``_nanargmin_prefer_larger_alpha``. Near-ties count as
    ties and are broken toward larger alpha, using these tolerances:

    * poisson with l1/elasticnet: ``rel_tol=5e-7, abs_tol=1e-6``. These
      curves are nearly flat at low alpha, and CPU/CuPy/Torch scores can
      differ at ~1e-7 from summation order.
    * other configurations for which ``_uses_glm_sparse_path`` is true:
      ``rel_tol=5e-6, abs_tol=1e-7``.
    * everything else: the helper's default tolerances.
    """
    penalty_name = str(self.penalty).lower()
    is_sparse_poisson = (
        str(self.loss).lower() == "poisson"
        and penalty_name in ("l1", "elasticnet", "en")
    )
    if is_sparse_poisson:
        tie_tolerances = {"rel_tol": 5e-7, "abs_tol": 1e-6}
    elif self._uses_glm_sparse_path(penalty_name, cv_solver):
        tie_tolerances = {"rel_tol": 5e-6, "abs_tol": 1e-7}
    else:
        tie_tolerances = {}
    return _nanargmin_prefer_larger_alpha(mean_scores, alpha_grid, **tie_tolerances)
2167
+
2168
def _compute_cv_scores(
    self,
    X,
    y,
    alpha_grid,
    cv_device,
    folds,
    sample_weight=None,
    max_iter=None,
    tol=None,
    strict=True,
):
    """Compute CV scores for exactly the supplied alpha grid.

    Strategies are tried in this order:

    1. Ridge eigendecomposition (squared_error + l2, unweighted, non-GPU
       device) via ``self._solve_ridge_fold_batch``.
    2. Fold-batched sparse GLM solve over all folds at once
       (``_glm_sparse_cv_folds``); only when ``strict`` is False, the loss
       is in ``_FOLD_BATCH_CONFIGS``, the penalty is l1/elastic-net and the
       device is "torch" or "cuda".
    3. Per fold: the first specialized path-solver (SCAD/MCP, logistic,
       squared-error sparse, GLM sparse) that is enabled for this
       loss/penalty and returns scores.
    4. Per fold fallback: ``self._cv_fold_general`` (one fit per alpha).

    Failures of strategies 1-3 emit a ``RuntimeWarning`` and move on.

    Parameters
    ----------
    X, y : array-like
        Full training data; rows are taken per fold with ``_slice_rows``.
    alpha_grid : array-like of float
        Alphas to score, in any order.
    cv_device : object
        Device spec, resolved with ``_device_to_name``.
    folds : sequence of (train_idx, val_idx)
        Fold index pairs.
    sample_weight : array-like or None
        Optional per-row weights, sliced per fold for training/validation.
    max_iter, tol : optional
        Overrides for ``self.max_iter`` / ``self.tol``.
    strict : bool
        When False, approximate fast paths (step 2, SCAD/MCP path-solver on
        non-squared-error losses, ``cv_mode`` for the GLM sparse path) are
        allowed.

    Returns
    -------
    numpy.ndarray of shape (n_folds, n_alphas)
        Validation loss per fold (rows) and alpha (columns, in the order of
        ``alpha_grid``). Entries that could not be computed stay NaN.
    """
    alpha_grid = np.asarray(alpha_grid, dtype=np.float64).ravel()
    n_alphas = len(alpha_grid)
    n_folds = len(folds)
    penalty_name = str(self.penalty).lower()
    loss_name = str(self.loss).lower()
    device_name = _device_to_name(cv_device)
    max_iter = int(self.max_iter if max_iter is None else max_iter)
    tol = self.tol if tol is None else tol
    sparse_penalty = penalty_name in ("l1", "elasticnet", "en")
    on_gpu = device_name in ("cuda", "torch")

    # ── Fast path: Ridge eigendecomposition (CPU only, unweighted) ──
    if (
        loss_name == "squared_error"
        and penalty_name == "l2"
        and sample_weight is None
        and not on_gpu
    ):
        all_scores = np.full((n_folds, n_alphas), np.nan)
        for fold_idx, (train_idx, val_idx) in enumerate(folds):
            X_train = _slice_rows(X, train_idx)
            y_train = _slice_rows(y, train_idx)
            X_val = _slice_rows(X, val_idx)
            y_val = _slice_rows(y, val_idx)
            try:
                mse, _, _ = self._solve_ridge_fold_batch(
                    X_train, y_train, X_val, y_val, alpha_grid,
                )
                all_scores[fold_idx, :] = mse
            except Exception as e:
                warnings.warn(
                    f"Ridge eig batch failed for fold {fold_idx}: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )
        return all_scores

    # Path solvers work from the largest alpha downwards; sort_idx maps a
    # position in alpha_sorted back to its column in alpha_grid.
    sort_idx = np.argsort(-alpha_grid)
    alpha_sorted = alpha_grid[sort_idx]
    all_scores = np.full((n_folds, n_alphas), np.nan)
    cv_solver = self._solver_for_cv(cv_device, X=X)

    # ── Fast path: fold-batched CV (all folds at once, GPU only) ──
    use_fold_batch = (
        not strict
        and loss_name in _FOLD_BATCH_CONFIGS
        and sparse_penalty
        and on_gpu
    )
    if use_fold_batch:
        try:
            path = _glm_sparse_cv_folds(
                X, y, folds, alpha_sorted, penalty_name, self.l1_ratio,
                max_iter, tol, loss_name, device_name,
                sample_weight=sample_weight,
                loss_kwargs=getattr(self, '_loss_kwargs', None),
            )
            if path is not None and path["scores"] is not None:
                all_scores[:, sort_idx] = path["scores"]
                return all_scores
        except Exception as e:
            warnings.warn(
                f"Fold-batched {loss_name} sparse CV failed on {device_name}; "
                f"falling back to per-fold path: {e}",
                RuntimeWarning,
                stacklevel=2,
            )

    # ── Per-fold specialized path-solvers ──
    # All share one signature so the loop below can call them uniformly.
    # Their __name__ appears in failure warnings.
    def _path_scad_mcp(X_train, y_train, alpha_sorted, penalty_name, l1_ratio,
                       max_iter, tol, cv_device, X_val, y_val, sw_train, sw_val):
        return _scad_mcp_cv_path(
            loss_name, X_train, y_train, alpha_sorted, penalty_name,
            l1_ratio, max_iter, tol, cv_device,
            X_val=X_val, y_val=y_val, sample_weight=sw_train,
            val_sample_weight=sw_val,
        )

    def _path_logistic(X_train, y_train, alpha_sorted, penalty_name, l1_ratio,
                       max_iter, tol, cv_device, X_val, y_val, sw_train, sw_val):
        return _logistic_sparse_cv_path(
            X_train, y_train, alpha_sorted, penalty_name, l1_ratio,
            max_iter, tol, cv_device,
            X_val=X_val, y_val=y_val, sample_weight=sw_train,
            val_sample_weight=sw_val, return_path=False,
        )

    def _path_squared(X_train, y_train, alpha_sorted, penalty_name, l1_ratio,
                      max_iter, tol, cv_device, X_val, y_val, sw_train, sw_val):
        return _squared_error_sparse_cv_path(
            X_train, y_train, alpha_sorted, penalty_name, l1_ratio,
            max_iter, tol, cv_device,
            X_val=X_val, y_val=y_val, sample_weight=sw_train,
            val_sample_weight=sw_val, return_path=False,
        )

    def _path_glm_sparse(X_train, y_train, alpha_sorted, penalty_name, l1_ratio,
                         max_iter, tol, cv_device, X_val, y_val, sw_train, sw_val):
        return _glm_sparse_cv_path(
            loss_name, X_train, y_train, alpha_sorted, penalty_name,
            l1_ratio, max_iter, tol, cv_device,
            X_val=X_val, y_val=y_val, sample_weight=sw_train,
            val_sample_weight=sw_val, return_path=False,
            solver_name=cv_solver, cv_mode=not strict,
        )

    # Which path-solvers apply to this loss/penalty combo (tried in order).
    candidates = (
        (
            penalty_name in ("scad", "mcp")
            and (loss_name == "squared_error" or not strict),
            _path_scad_mcp,
        ),
        (loss_name == "logistic" and sparse_penalty, _path_logistic),
        (loss_name == "squared_error" and sparse_penalty, _path_squared),
        (self._uses_glm_sparse_path(penalty_name, cv_solver), _path_glm_sparse),
    )
    active_paths = [path_fn for enabled, path_fn in candidates if enabled]

    # ── Per-fold loop ──
    for fold_idx, (train_idx, val_idx) in enumerate(folds):
        X_train = _slice_rows(X, train_idx)
        y_train = _slice_rows(y, train_idx)
        X_val = _slice_rows(X, val_idx)
        y_val = _slice_rows(y, val_idx)
        if sample_weight is not None:
            sw_train = _slice_rows(sample_weight, train_idx)
            sw_val = _slice_rows(sample_weight, val_idx)
        else:
            sw_train = sw_val = None

        # First path-solver that yields scores wins; failures fall through.
        fold_handled = False
        for path_fn in active_paths:
            try:
                path = path_fn(
                    X_train, y_train, alpha_sorted, penalty_name,
                    self.l1_ratio, max_iter, tol, cv_device,
                    X_val=X_val, y_val=y_val,
                    sw_train=sw_train, sw_val=sw_val,
                )
            except Exception as e:
                warnings.warn(
                    f"{path_fn.__name__} failed for {loss_name}+{penalty_name} "
                    f"fold {fold_idx}: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue
            if path is not None and path["scores"] is not None:
                all_scores[fold_idx, sort_idx] = path["scores"]
                fold_handled = True
                break

        if fold_handled:
            continue

        # ── General fallback: model.fit() per alpha ──
        self._cv_fold_general(
            all_scores, fold_idx, sort_idx, alpha_sorted,
            loss_name, cv_device, cv_solver, strict,
            X_train, y_train, X_val, y_val,
            sw_train, sw_val, max_iter, tol,
        )

    return all_scores
2353
+
2354
def _cv_fold_general(
    self, all_scores, fold_idx, sort_idx, alpha_sorted,
    loss_name, cv_device, cv_solver, strict,
    X_train, y_train, X_val, y_val,
    sw_train, sw_val, max_iter, tol,
):
    """General per-fold CV path: model.fit() per alpha with warm-start.

    Writes validation losses for one fold into ``all_scores`` in place:
    the score for ``alpha_sorted[k]`` goes to
    ``all_scores[fold_idx, sort_idx[k]]``. Alphas whose fit fails (or that
    cannot be scored) are left as NaN. Returns None.

    For SCAD/MCP on a non-squared-error loss with ``strict=False``, first
    tries a single LLA path fit over the whole ``alpha_sorted`` grid. It
    falls back to the per-alpha loop if that path yields nothing.
    """
    from statgpu.linear_model.penalized._base import PenalizedGeneralizedLinearModel
    from statgpu.linear_model.penalized._fit_mixin import _resolve_loss_name

    penalty_name = str(self.penalty).lower()
    device_name = _device_to_name(cv_device)

    X_val_np = _to_numpy(X_val).astype(np.float64)
    y_val_np = _to_numpy(y_val).astype(np.float64).ravel()
    loss_fn = _resolve_loss_name(loss_name)

    # Disable warm-start for SCAD/MCP on non-squared-error losses
    is_scad_mcp_non_se = penalty_name in ("scad", "mcp") and loss_name != "squared_error"
    use_warm_start = not is_scad_mcp_non_se
    use_lla_path_cv = not strict and is_scad_mcp_non_se

    # Transfer training data to the GPU backend if needed
    if device_name in ("cuda", "torch"):
        fold_backend = _backend_name_for_cv_device(cv_device)
        X_train_fit = _to_backend_float64(X_train, fold_backend)
        y_train_fit = _to_backend_float64(y_train, fold_backend)
        sw_train_fit = _to_backend_float64(sw_train, fold_backend) if sw_train is not None else None
    else:
        X_train_fit = X_train
        y_train_fit = y_train
        sw_train_fit = sw_train

    # Precompute XtX/Xty for squared-error GPU cache
    cv_cache, L_np = self._build_cv_cache(
        loss_name, device_name, X_train, y_train, sw_train
    )

    model = PenalizedGeneralizedLinearModel(
        loss=loss_name, penalty=self.penalty, alpha=alpha_sorted[0],
        l1_ratio=self.l1_ratio, device=cv_device, compute_inference=False,
        max_iter=max_iter, tol=tol, solver=cv_solver,
    )
    if cv_cache is not None:
        model._cv_cache = cv_cache
        model._preserve_cv_cache = True
        if L_np is not None and L_np > 0:
            model.lipschitz_L = L_np

    # LLA path for SCAD/MCP
    if use_lla_path_cv:
        try:
            model.alpha = float(alpha_sorted[-1])
            if hasattr(model, "_penalty") and model._penalty is not None:
                model._penalty.alpha = float(alpha_sorted[-1])
            model._cv_alpha_path = np.asarray(alpha_sorted, dtype=np.float64)
            model.fit(X_train_fit, y_train_fit, sample_weight=sw_train_fit)
            path = getattr(model, "_cv_path_results", None)
            if path is not None:
                path_alphas = np.asarray(path["alpha"], dtype=np.float64)
                path_coefs = np.asarray(path["coef"], dtype=np.float64)
                path_intercepts = np.asarray(path["intercept"], dtype=np.float64)
                for alpha_idx_sorted, alpha in enumerate(alpha_sorted):
                    matches = np.flatnonzero(np.isclose(path_alphas, float(alpha), rtol=1e-10, atol=1e-14))
                    if matches.size == 0:
                        continue
                    path_idx = int(matches[-1])
                    val_loss = _evaluate_loss_numpy(
                        loss_name, loss_fn, X_val_np, y_val_np,
                        path_coefs[path_idx], float(path_intercepts[path_idx]),
                        True, sample_weight=sw_val,
                    )
                    all_scores[fold_idx, sort_idx[alpha_idx_sorted]] = val_loss
                for attr in ("_cv_alpha_path", "_cv_path_results", "_cv_cache", "_preserve_cv_cache"):
                    if hasattr(model, attr):
                        delattr(model, attr)
                return
            # path is None — cleanup LLA state only; _cv_cache and
            # _preserve_cv_cache are still needed for the warm-start fallback.
            for attr in ("_cv_alpha_path", "_cv_path_results"):
                if hasattr(model, attr):
                    delattr(model, attr)
        except Exception as exc:
            # Best-effort: record why the LLA path was skipped, then fall
            # back to per-alpha fits (keeping _cv_cache for them).
            logger.debug(
                "CV fold %d: LLA path for %s+%s failed, falling back to "
                "per-alpha fits: %s",
                fold_idx, loss_name, penalty_name, exc,
            )
            for attr in ("_cv_alpha_path", "_cv_path_results"):
                if hasattr(model, attr):
                    delattr(model, attr)

    # Warm-started alpha loop: fit per alpha, collect coefs for batch eval
    prev_coef = None
    prev_intercept = None
    fitted_coefs = []  # (alpha_idx_sorted, coef_np, intercept)
    for alpha_idx_sorted, alpha in enumerate(alpha_sorted):
        try:
            if cv_cache is not None:
                model._cv_cache = cv_cache
            model.alpha = alpha
            if hasattr(model, "_penalty") and model._penalty is not None:
                model._penalty.alpha = alpha
            if use_warm_start and prev_coef is not None:
                model._init_coef = np.asarray(prev_coef, dtype=np.float64)
                model._init_intercept = prev_intercept
            else:
                model._init_coef = None
                model._init_intercept = None
            model.fit(X_train_fit, y_train_fit, sample_weight=sw_train_fit)

            coef_np = _to_numpy(model.coef_).ravel()
            intercept = float(model.intercept_)
            fitted_coefs.append((alpha_idx_sorted, coef_np.copy(), intercept))
            prev_coef = coef_np.copy()
            prev_intercept = intercept
        except Exception as exc:
            orig_idx = sort_idx[alpha_idx_sorted]
            all_scores[fold_idx, orig_idx] = np.nan
            logger.warning(
                "CV fold %d, alpha_idx %d (alpha=%.6g) fit failed: %s",
                fold_idx, orig_idx, alpha_sorted[alpha_idx_sorted], exc,
            )

    # Batch validation: one GEMM for all fitted alphas
    # Pre-build loss-specific params
    _loss_params = {}
    if loss_name == "negative_binomial":
        _loss_params["alpha"] = float(getattr(loss_fn, "alpha", _NB_ALPHA_DEFAULT))
    elif loss_name == "tweedie":
        _loss_params["power"] = float(getattr(loss_fn, "power", _TWEEDIE_POWER_DEFAULT))

    if fitted_coefs:
        idxs = np.array([fc[0] for fc in fitted_coefs])
        coef_mat = np.column_stack([fc[1] for fc in fitted_coefs])  # (n_features, n_fitted)
        intercepts = np.array([fc[2] for fc in fitted_coefs])  # (n_fitted,)
        eta_mat = X_val_np @ coef_mat + intercepts[np.newaxis, :]  # (n_val, n_fitted)

        # Evaluate loss per alpha
        sw = np.asarray(_to_numpy(sw_val), dtype=np.float64).ravel() if sw_val is not None else None
        per_sample_loss = None

        if loss_name == "squared_error":
            # Direct batch computation: squared residual
            per_sample_loss = (y_val_np[:, np.newaxis] - eta_mat) ** 2
        else:
            # GLM losses: use registry
            entry = _LOSS_EVAL_DISPATCH.get(loss_name)
            if entry is not None:
                per_sample_fn, _ = entry
                per_sample_loss = per_sample_fn(eta_mat, y_val_np[:, np.newaxis], **_loss_params)

        if per_sample_loss is not None:
            if sw is not None:
                w_sum = float(np.sum(sw))
                if w_sum > 0:
                    scores_fitted = np.sum(sw[:, np.newaxis] * per_sample_loss, axis=0) / w_sum
                else:
                    scores_fitted = np.mean(per_sample_loss, axis=0)
            else:
                scores_fitted = np.mean(per_sample_loss, axis=0)
            for i, alpha_idx_sorted in enumerate(idxs):
                orig_idx = sort_idx[alpha_idx_sorted]
                all_scores[fold_idx, orig_idx] = float(scores_fitted[i])
        else:
            # No vectorized per-sample loss is registered for this loss.
            # Score each fitted coefficient vector with the same helper the
            # LLA branch uses, so that successful fits are not dropped as
            # NaN. Failures stay NaN (best-effort, like the fit loop).
            for alpha_idx_sorted, coef_np, intercept in fitted_coefs:
                orig_idx = sort_idx[alpha_idx_sorted]
                try:
                    all_scores[fold_idx, orig_idx] = _evaluate_loss_numpy(
                        loss_name, loss_fn, X_val_np, y_val_np,
                        coef_np, intercept, True, sample_weight=sw_val,
                    )
                except Exception as exc:
                    logger.warning(
                        "CV fold %d, alpha_idx %d: validation loss for %s "
                        "could not be evaluated: %s",
                        fold_idx, orig_idx, loss_name, exc,
                    )

    if hasattr(model, "_cv_cache"):
        del model._cv_cache
    if hasattr(model, "_preserve_cv_cache"):
        del model._preserve_cv_cache
2517
+
2518
def _build_cv_cache(self, loss_name, device_name, X_train, y_train, sw_train):
    """Precompute XtX/Xty for squared-error GPU cache. Returns (cache_dict, L_np).

    Only active for ``loss_name == "squared_error"`` on a "cuda" or "torch"
    device; otherwise returns ``(None, None)``.

    Returns
    -------
    cache : dict or None
        ``{"XtX", "Xty", "n_effective"}``. XtX/Xty are built from
        (weighted-)mean-centered X and y (row weights ``sw_train`` when
        given). They are placed on CuPy for "cuda", or as float64 torch
        tensors for "torch". ``n_effective`` is the weight sum, or the row
        count when unweighted.
    L_np : float or None
        ``_max_eigval_power(XtX) / n_effective``, used as the solver's
        Lipschitz constant.
    """
    if loss_name != "squared_error" or device_name not in ("cuda", "torch"):
        return None, None
    X_train_np = _to_numpy(X_train).astype(np.float64)
    y_train_np = _to_numpy(y_train).astype(np.float64).ravel()
    n_tr, _ = X_train_np.shape
    sw_np = _to_numpy(sw_train).astype(np.float64).ravel() if sw_train is not None else None
    if sw_np is not None:
        w_sum = float(sw_np.sum())
        X_mean_np = np.average(X_train_np, axis=0, weights=sw_np)
        y_mean_np = float(np.average(y_train_np, weights=sw_np))
        Xc_np = X_train_np - X_mean_np
        yc_np = y_train_np - y_mean_np
        sqrt_w = np.sqrt(sw_np)
        W_Xc = Xc_np * sqrt_w[:, None]
        XtX_np = W_Xc.T @ W_Xc
        Xty_np = (Xc_np * sw_np[:, None]).T @ yc_np
        n_effective = w_sum
        # Normalize by the same quantity stored as n_effective. Clamping
        # the divisor at 1.0 would under-estimate L whenever the fold's
        # weights sum below one (e.g. globally normalized weights), giving
        # a step size larger than 1/L. Guard only against a non-positive sum.
        L_np = float(_max_eigval_power(XtX_np)) / (w_sum if w_sum > 0 else 1.0)
    else:
        X_mean_np = np.mean(X_train_np, axis=0)
        y_mean_np = np.mean(y_train_np)
        Xc_np = X_train_np - X_mean_np
        yc_np = y_train_np - y_mean_np
        XtX_np = Xc_np.T @ Xc_np
        Xty_np = Xc_np.T @ yc_np
        L_np = float(_max_eigval_power(XtX_np)) / n_tr
        n_effective = float(n_tr)
    if device_name == "cuda":
        import cupy as cp
        cache = {"XtX": cp.asarray(XtX_np), "Xty": cp.asarray(Xty_np), "n_effective": n_effective}
    else:
        import torch
        # NOTE(review): uses the default CUDA device when available, else CPU;
        # a specific device index in cv_device is not consulted here.
        _torch_dev = "cuda" if torch.cuda.is_available() else "cpu"
        cache = {"XtX": torch.as_tensor(XtX_np, device=_torch_dev, dtype=torch.float64),
                 "Xty": torch.as_tensor(Xty_np, device=_torch_dev, dtype=torch.float64),
                 "n_effective": n_effective}
    return cache, L_np
2557
+
2558
def fit(self, X, y, sample_weight=None):
    """Fit the CV model with optimized strict or explicit two-stage CV.

    Steps:

    1. Build the alpha grid: the user grid, or ``self._generate_alpha_grid``.
    2. Score the grid with ``self._compute_cv_scores`` over the folds.
       Folds come from ``self.cv_splits`` or ``kfold_indices``.
       - ``cv_strategy == "two_stage"``: a relaxed stage (fewer iterations,
         looser tol, ``strict=False``) screens the grid. Only the candidates
         chosen by ``_two_stage_candidate_mask`` are then re-scored strictly.
       - Otherwise: every alpha is scored strictly.
    3. Pick the best alpha via ``self._best_index_from_scores``.
    4. Refit on all data via ``self._refit_best``.

    Parameters
    ----------
    X, y : array-like
        Training data; non-array inputs are converted to float64 arrays.
    sample_weight : array-like or None
        Optional per-row weights; non-array inputs are converted to float64.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If the alpha grid is empty.
    """
    # Normalize array-like inputs (lists, tuples, etc.) to arrays so the CV
    # helpers can slice rows uniformly.
    if not hasattr(X, 'shape'):
        X = np.asarray(X, dtype=np.float64)
    if not hasattr(y, 'shape'):
        y = np.asarray(y, dtype=np.float64)
    if sample_weight is not None and not hasattr(sample_weight, 'shape'):
        sample_weight = np.asarray(sample_weight, dtype=np.float64)

    if self._alpha_grid_input is not None:
        alpha_grid = np.asarray(self._alpha_grid_input, dtype=np.float64)
    else:
        alpha_grid = self._generate_alpha_grid(X, y)
    alpha_grid = np.asarray(alpha_grid, dtype=np.float64).ravel()
    if alpha_grid.size == 0:
        raise ValueError("PenalizedGLM_CV requires a non-empty alpha grid.")

    self.alpha_grid_ = alpha_grid
    n_samples = X.shape[0]
    n_alphas = len(alpha_grid)
    penalty_name = str(self.penalty).lower()
    loss_name = str(self.loss).lower()
    cv_device = self._effective_cv_device(X, penalty_name, n_alphas)
    cv_solver = self._solver_for_cv(cv_device, X=X)
    self.cv_strategy_ = self.cv_strategy
    self.cv_selected_device_ = _device_to_name(cv_device)

    if self.cv_splits is not None:
        # Normalize to list (generators would exhaust on first pass)
        folds = list(self.cv_splits) if not isinstance(self.cv_splits, list) else self.cv_splits
    else:
        folds = kfold_indices(n_samples, self.cv, self.random_state)
    all_scores_stage1 = None
    mean_scores_stage1 = None
    refined_mask = np.ones(n_alphas, dtype=bool)

    if self.cv_strategy == "two_stage":
        if not self.acknowledge_approx:
            warnings.warn(
                "PenalizedGLM_CV(cv_strategy='two_stage') uses relaxed CV "
                "solves to screen the alpha grid before strict refinement. "
                "The final refit still uses the original max_iter and tol. "
                "Pass acknowledge_approx=True to silence this warning.",
                ApproximateCVWarning,
                stacklevel=2,
            )
        # Stage 1: relaxed screening solves.
        stage1_max_iter = min(int(self.max_iter), max(50, int(self.max_iter) // 4))
        stage1_tol = max(float(self.tol) * 10.0, 1e-4)
        all_scores_stage1 = self._compute_cv_scores(
            X,
            y,
            alpha_grid,
            cv_device,
            folds,
            sample_weight=sample_weight,
            max_iter=stage1_max_iter,
            tol=stage1_tol,
            strict=False,
        )
        mean_scores_stage1 = np.nanmean(all_scores_stage1, axis=0)
        refined_mask = _two_stage_candidate_mask(
            mean_scores_stage1,
            refine_top_k=self.refine_top_k,
        )
        # Squared-error SCAD/MCP: refine the whole grid.
        if loss_name == "squared_error" and penalty_name in ("scad", "mcp"):
            refined_mask[:] = True
        if not np.any(refined_mask):
            refined_mask[:] = True

        # Stage 2: strict solves on the candidate alphas only.
        refined_alpha_grid = alpha_grid[refined_mask]
        refined_scores = self._compute_cv_scores(
            X,
            y,
            refined_alpha_grid,
            cv_device,
            folds,
            sample_weight=sample_weight,
            max_iter=self.max_iter,
            tol=self.tol,
            strict=True,
        )
        # Reported scores: strict where refined, relaxed elsewhere.
        all_scores = np.array(all_scores_stage1, copy=True)
        all_scores[:, refined_mask] = refined_scores
        mean_scores = np.nanmean(all_scores, axis=0)
        # Selection uses strict scores only.
        refined_mean = np.nanmean(refined_scores, axis=0)
        refined_best = self._best_index_from_scores(
            refined_mean,
            refined_alpha_grid,
            cv_solver,
        )
        best_idx = int(np.flatnonzero(refined_mask)[refined_best])
    else:
        all_scores = self._compute_cv_scores(
            X,
            y,
            alpha_grid,
            cv_device,
            folds,
            sample_weight=sample_weight,
            max_iter=self.max_iter,
            tol=self.tol,
            strict=True,
        )
        mean_scores = np.nanmean(all_scores, axis=0)
        best_idx = self._best_index_from_scores(mean_scores, alpha_grid, cv_solver)

    best_alpha = float(alpha_grid[best_idx])
    self.alpha_ = best_alpha
    # sklearn convention: best_score_ is negative loss (higher is better)
    self.best_score_ = -float(mean_scores[best_idx])
    self.cv_results_ = {
        "alpha": alpha_grid,
        "mean_score": mean_scores,
        "all_scores": all_scores,
        "cv_strategy_": self.cv_strategy_,
        "cv_selected_device_": self.cv_selected_device_,
        "mean_score_stage1": mean_scores_stage1,
        "all_scores_stage1": all_scores_stage1,
        "refined_mask": refined_mask,
    }

    self.estimator_ = self._refit_best(X, y, best_alpha, sample_weight=sample_weight)
    self.coef_ = self.estimator_.coef_
    self.intercept_ = self.estimator_.intercept_

    self._fitted = True
    return self
2681
+
2682
def predict(self, X):
    """Predict with the estimator refit at the selected alpha.

    Delegates to ``self.estimator_.predict``.

    Raises
    ------
    RuntimeError
        If ``fit`` has not completed on this instance.
    """
    is_fitted = getattr(self, '_fitted', False)
    if is_fitted:
        return self.estimator_.predict(X)
    raise RuntimeError("PenalizedGLM_CV is not fitted yet. Call fit() first.")
2687
+
2688
def score(self, X, y, sample_weight=None):
    """Score the refit estimator on ``(X, y)``.

    Delegates to ``self.estimator_.score(X, y, sample_weight=...)``; the
    metric is whatever that estimator reports.

    Note: ``best_score_`` is the negated mean CV loss (sklearn
    "higher is better" convention), which is a different metric from the
    value returned here.

    Raises
    ------
    RuntimeError
        If ``fit`` has not completed on this instance.
    """
    is_fitted = getattr(self, '_fitted', False)
    if is_fitted:
        return self.estimator_.score(X, y, sample_weight=sample_weight)
    raise RuntimeError("PenalizedGLM_CV is not fitted yet. Call fit() first.")