statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1877 @@
1
+ """Fit mixin for PenalizedGeneralizedLinearModel."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ from typing import TYPE_CHECKING
7
+
8
+ from statgpu._config import Device
9
+ from statgpu.backends import get_backend, _get_torch_device_str, _to_numpy, _LINALG_ERRORS
10
+ from statgpu.solvers._utils import _nesterov_momentum, _nesterov_update
11
+
12
+ if TYPE_CHECKING:
13
+ from ._base import PenalizedGeneralizedLinearModel as _Self
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # Solver dispatch table for solver='auto'
17
+ # ---------------------------------------------------------------------------
18
+ # Each entry is (solver, condition_fn). First match wins.
19
+ # condition_fn takes (loss, penalty, backend, l1_ratio, cv_mode, problem_size).
20
+
21
+ # Import shared penalty categories (single source of truth)
22
+ from statgpu.penalties._categories import (
23
+ NONCONVEX as _NONCONVEX_PENALTIES,
24
+ SPARSE as _SPARSE_PENALTIES,
25
+ )
26
_SMOOTH_PENALTIES = frozenset({"l2", "none", "null", ""})

# Ordered (solver, predicate) rules consumed by
# ``_preferred_penalized_glm_solver``. The FIRST predicate that returns True
# picks the solver, so the order below is significant. Predicates are called
# positionally as
#     predicate(loss, penalty, backend, l1_ratio, cv_mode, problem_size)
# where problem_size is n_samples * n_features or None. No rule currently
# inspects l1_ratio; it is accepted so every predicate has the same signature.
_SOLVER_DISPATCH_TABLE = [
    # 1. Exact closed form: ridge + squared error has an eigendecomposition
    #    solver, so it always wins.
    ("exact", lambda loss, pen, be, _ratio, cv, size: (
        loss == "squared_error" and pen == "l2")),

    # 2. Non-convex penalties (SCAD/MCP/adaptive_l1) need iteratively
    #    reweighted L1 (LLA approximation) and always use the FISTA wrapper.
    ("fista", lambda loss, pen, be, _ratio, cv, size: (
        pen in _NONCONVEX_PENALTIES)),

    # 3. Quadratic loss + sparse penalty (L1/ElasticNet): FISTA.
    ("fista", lambda loss, pen, be, _ratio, cv, size: (
        loss == "squared_error" and pen in _SPARSE_PENALTIES)),

    # 4. Poisson under CV.
    #    GPU + L1 below 2M elements: BB step sizes.
    ("fista_bb", lambda loss, pen, be, _ratio, cv, size: (
        cv and loss == "poisson" and be in ("cupy", "torch") and pen == "l1"
        and (size is None or size < 2_000_000))),
    #    GPU + ElasticNet: BB steps adapt well to the EN geometry.
    ("fista_bb", lambda loss, pen, be, _ratio, cv, size: (
        cv and loss == "poisson" and be in ("cupy", "torch")
        and pen in ("elasticnet", "en"))),
    #    Every other sparse case (CPU, or large GPU L1): FISTA with
    #    backtracking.
    ("fista", lambda loss, pen, be, _ratio, cv, size: (
        cv and loss == "poisson" and pen in _SPARSE_PENALTIES)),

    # 5. Negative binomial on GPU under CV.
    #    L1: BB steps (the NB gradient is well-behaved for them).
    ("fista_bb", lambda loss, pen, be, _ratio, cv, size: (
        cv and loss == "negative_binomial" and be in ("cupy", "torch")
        and pen == "l1")),
    #    ElasticNet with 200K <= size < 1M: FISTA ...
    ("fista", lambda loss, pen, be, _ratio, cv, size: (
        cv and loss == "negative_binomial" and be in ("cupy", "torch")
        and pen in ("elasticnet", "en")
        and size is not None and 200_000 <= size < 1_000_000)),
    #    ... any other ElasticNet size: BB steps.
    ("fista_bb", lambda loss, pen, be, _ratio, cv, size: (
        cv and loss == "negative_binomial" and be in ("cupy", "torch")
        and pen in ("elasticnet", "en"))),

    # 6. Steep-loss families + sparse penalty: FISTA with backtracking.
    ("fista", lambda loss, pen, be, _ratio, cv, size: (
        loss in ("gamma", "inverse_gaussian") and pen in _SPARSE_PENALTIES)),
    ("fista", lambda loss, pen, be, _ratio, cv, size: (
        loss == "tweedie" and be in ("cupy", "torch")
        and pen in _SPARSE_PENALTIES)),

    # 7. Logistic + sparse under CV: FISTA (iterate-dependent Lipschitz).
    ("fista", lambda loss, pen, be, _ratio, cv, size: (
        cv and loss == "logistic" and pen in _SPARSE_PENALTIES)),

    # 8. Catch-all for any remaining sparse penalty: BB steps.
    ("fista_bb", lambda loss, pen, be, _ratio, cv, size: (
        pen in _SPARSE_PENALTIES)),

    # 9. L2 under CV: loss-specific smooth solvers.
    ("lbfgs", lambda loss, pen, be, _ratio, cv, size: (
        cv and pen == "l2" and loss == "negative_binomial")),
    ("newton", lambda loss, pen, be, _ratio, cv, size: (
        cv and pen == "l2" and loss in ("poisson", "tweedie"))),
    ("lbfgs", lambda loss, pen, be, _ratio, cv, size: (
        cv and pen == "l2" and loss in ("gamma", "inverse_gaussian"))),

    # 10. Smooth penalties (L2 / none) outside the rules above.
    ("newton", lambda loss, pen, be, _ratio, cv, size: (
        pen in _SMOOTH_PENALTIES
        and loss in ("gamma", "tweedie", "inverse_gaussian"))),
    ("irls", lambda loss, pen, be, _ratio, cv, size: (
        pen in _SMOOTH_PENALTIES
        and loss in ("logistic", "poisson", "negative_binomial"))),
]
83
+
84
+
85
def _preferred_penalized_glm_solver(
    loss_name,
    penalty_name,
    backend_name=None,
    l1_ratio=0.5,
    cv_mode=False,
    problem_size=None,
):
    """Pick the internal solver used when ``solver='auto'``.

    The decision follows the ordered rules in ``_SOLVER_DISPATCH_TABLE``. The
    first rule whose predicate matches decides, and ``"fista"`` is the
    fallback when none match.

    This is a private, benchmark-backed policy helper. It must never override
    a solver the user asked for explicitly, nor change the selected device.

    Parameters
    ----------
    loss_name, penalty_name, backend_name : str or None
        Names compared case-insensitively. ``None``, like any other falsy
        value, is treated as ``""``.
    l1_ratio : float
        Forwarded unchanged to every rule predicate.
    cv_mode : bool
        Whether the fit happens inside a cross-validation loop.
    problem_size : int or None
        Typically ``n_samples * n_features``. When given, it is coerced
        with ``int()``.

    Returns
    -------
    str
        The chosen solver name.
    """
    def _normalise(name):
        return str(name or "").lower()

    size = None if problem_size is None else int(problem_size)
    rule_args = (
        _normalise(loss_name),
        _normalise(penalty_name),
        _normalise(backend_name),
        l1_ratio,
        cv_mode,
        size,
    )
    return next(
        (solver for solver, predicate in _SOLVER_DISPATCH_TABLE if predicate(*rule_args)),
        "fista",
    )
111
+
112
+
113
def _resolve_loss_name(loss_name, loss_kwargs=None):
    """Instantiate the GLM loss registered under ``loss_name``.

    ``loss_kwargs`` is forwarded as keyword arguments to ``get_glm_loss``. A
    falsy value (``None``, ``{}``) means no extra arguments.
    """
    from statgpu.glm_core._base import get_glm_loss as _lookup_loss

    return _lookup_loss(loss_name, **(loss_kwargs or {}))
118
+
119
+
120
def _irls_ridge_init(X, y, loss_name, alpha=0.01, max_iter=100, tol=1e-4, loss_kwargs=None):
    """Compute ridge-penalized GLM coefficients used to seed adaptive_l1.

    * ``squared_error`` (or an empty loss name) is solved in closed form by
      ``_irls_ridge_init_cd`` on column-rescaled features.
    * Every other loss is fitted with ``fista_solver`` under an L2 penalty of
      strength ``alpha``. FISTA's line search handles extreme ``y`` values
      robustly.

    Parameters
    ----------
    X : array of shape (n, p)
        Feature matrix without an intercept column. Arrays are passed to the
        solvers as-is; they detect the backend internally.
    y : array of shape (n,)
        Response vector.
    loss_name : str
        GLM loss registry name: 'logistic', 'poisson', 'squared_error', ...
    alpha : float
        Ridge penalty strength (lambda in R glmnet).
    max_iter : int
        Maximum FISTA iterations. Only the GLM path uses it; it is forwarded
        to, but not used by, the closed-form solve.
    tol : float
        FISTA convergence tolerance. As with ``max_iter``, only the GLM path
        uses it.
    loss_kwargs : dict, optional
        Extra keyword arguments used to construct the GLM loss.

    Returns
    -------
    coef : ndarray of shape (p,), dtype float64
        Ridge-penalized coefficient estimates (no intercept), returned as a
        NumPy array because the caller feeds it to ``penalty.set_weights``.
    """
    if loss_name in ("squared_error", ""):
        coef = _irls_ridge_init_cd(X, y, alpha, max_iter, tol)
    else:
        from statgpu.penalties import get_penalty
        from statgpu.solvers import fista_solver

        l2_pen = get_penalty("l2", alpha=alpha)
        loss_obj = _resolve_loss_name(loss_name, loss_kwargs=loss_kwargs)
        coef, _ = fista_solver(loss_obj, l2_pen, X, y, max_iter=max_iter, tol=tol)
    # Use the module-level ``_to_numpy``. A function-local re-import would
    # turn ``_to_numpy`` into a local name for the whole function.
    return np.asarray(_to_numpy(coef), dtype=np.float64)
160
+
161
+
162
def _irls_ridge_init_cd(X, y, alpha, max_iter, tol):
    """Closed-form ridge fit used to seed adaptive-L1 weights.

    Despite the ``_cd`` suffix, no coordinate descent is performed. Each
    column of ``X`` is rescaled as ``X_s = X * sqrt(n) / ||x_j||``. The
    function then solves

        (X_s' X_s / n + alpha * I) beta_s = X_s' y / n

    directly and maps the result back with ``beta = beta_s * sqrt(n) / ||x_j||``.
    A single matmul plus a linear solve parallelises well on a GPU.

    Parameters
    ----------
    X : array of shape (n, p)
        NumPy, CuPy or Torch feature matrix; the backend is detected from X.
        Non-floating inputs are promoted to float64. Without this, the
        ``sqrt(n)`` factor would be truncated when cast to ``X.dtype``.
    y : array of shape (n,)
        Response on the same backend as ``X``. Non-floating NumPy/CuPy ``y``
        is cast to ``X.dtype``. On Torch, ``y`` is cast to ``X.dtype`` whenever
        the dtypes differ, because matmul needs matching dtypes.
    alpha : float
        Ridge penalty strength on the rescaled problem.
    max_iter, tol
        Unused; accepted for signature compatibility with iterative solvers.

    Returns
    -------
    beta : array of shape (p,)
        Ridge coefficients on the original feature scale, on X's backend.
    """
    from statgpu.backends import _resolve_backend
    from statgpu.backends._utils import _get_xp

    backend = _resolve_backend("auto", X)
    xp = _get_xp(backend)

    if backend == "torch":
        import torch

        if not torch.is_floating_point(X):
            X = X.to(torch.float64)
        if y.dtype != X.dtype:
            # torch matmul rejects mixed dtypes; align y with X.
            y = y.to(X.dtype)
    else:
        # Integer X would truncate sqrt(n) in ``scale`` (cast to X.dtype)
        # and X ** 2 could overflow; promote to float64.
        if X.dtype.kind not in "fc":
            X = X.astype(np.float64)
        y_dtype = getattr(y, "dtype", None)
        if y_dtype is not None and y_dtype.kind not in "fc":
            y = y.astype(X.dtype)

    n, p = X.shape
    # Rescale every column to squared norm n. The 1e-20 floor avoids division
    # by zero for all-zero columns (with alpha > 0 their coefficient is 0).
    feat_norms = xp.sqrt(xp.sum(X ** 2, axis=0))
    root_n = float(n) ** 0.5
    if backend == "torch":
        feat_norms = xp.maximum(
            feat_norms,
            torch.tensor(1e-20, dtype=feat_norms.dtype, device=feat_norms.device),
        )
        scale = torch.tensor(root_n, dtype=X.dtype, device=X.device) / feat_norms
    else:
        feat_norms = xp.maximum(feat_norms, 1e-20)
        scale = xp.asarray(root_n, dtype=X.dtype) / feat_norms
    X_work = X * scale

    # Closed-form ridge on the rescaled design.
    XtX = X_work.T @ X_work / n
    Xty = X_work.T @ y / n

    if backend == "torch":
        eye = torch.eye(p, dtype=X.dtype, device=X.device)
        beta = torch.linalg.solve(XtX + alpha * eye, Xty)
    elif backend == "cupy":
        import cupy as cp

        eye = cp.eye(p, dtype=X.dtype)
        beta = cp.linalg.solve(XtX + alpha * eye, Xty)
    else:
        eye = np.eye(p, dtype=X.dtype)
        beta = np.linalg.solve(XtX + alpha * eye, Xty)

    # Undo the column rescaling.
    return beta * scale
204
+
205
+
206
+ class _PenalizedFitMixin:
207
+
208
def fit(self, X=None, y=None, sample_weight=None, formula=None, data=None):
    """
    Fit penalized GLM model.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features), optional
        Training data. Required when ``formula`` is None. Inputs without a
        ``.shape`` attribute (lists, tuples) are converted with
        ``np.asarray``. Anything that already exposes ``.shape`` (NumPy,
        CuPy, Torch, pandas) is passed through unchanged.
    y : array-like of shape (n_samples,), optional
        Target values. Required when ``formula`` is None. Converted like ``X``.
    sample_weight : array-like of shape (n_samples,), optional
        Sample weights.
    formula : str, optional
        R-style formula string, e.g. ``"y ~ x1 + C(group)"``. The formula's
        intercept term determines ``_use_intercept``.
    data : pandas.DataFrame, optional
        Data used to evaluate ``formula``.

    Returns
    -------
    self
        The fitted estimator.

    Raises
    ------
    ValueError
        If ``formula`` is given without ``data``, or if neither
        ``formula``+``data`` nor both ``X`` and ``y`` are provided.
    """
    if formula is not None:
        if data is None:
            raise ValueError(
                "formula was provided but data is None. "
                "Pass data=your_dataframe when using formula."
            )
        from statgpu.core.formula import FormulaParser

        parser = FormulaParser(formula)
        y, X, design_info = parser.eval(data)
        formula_column_names = list(design_info.column_names)
        self._design_info = design_info
        self._formula_has_intercept = "Intercept" in formula_column_names
        self._feature_names = [name for name in formula_column_names if name != "Intercept"]
        if self._formula_has_intercept:
            X = np.delete(X, formula_column_names.index("Intercept"), axis=1)
            self._use_intercept = True
        else:
            # Formula syntax owns intercept semantics, matching statsmodels/R.
            self._use_intercept = False
    else:
        if X is None or y is None:
            raise ValueError("Either formula+data or X+y must be provided.")
        self._feature_names = None
        self._design_info = None
        self._formula_has_intercept = None
        self._use_intercept = None

    # Plain sequences have no ``.shape``. The size-based routing, the initial
    # fit and the df_resid computation below all read ``X.shape``, so convert
    # once here. Backend arrays keep their original type.
    if X is not None and not hasattr(X, "shape"):
        X = np.asarray(X)
    if y is not None and not hasattr(y, "shape"):
        y = np.asarray(y)

    if X is not None:
        # Record number of features for sklearn compatibility.
        self.n_features_in_ = X.shape[1] if X.ndim >= 2 else 1

    self._penalty = self._resolve_penalty()
    self._validate_solver_penalty()
    self._loss = self._resolve_loss()
    self._validate_inference_request()
    self._inference_precomputed = False
    self._precomputed_gaussian_state = None
    self._clear_inference_state()

    # Resolve the actual backend before auto-selecting the solver. This keeps
    # solver="auto" device-aware.
    backend = self._get_backend(backend="auto")
    backend_name = backend.name

    # device="auto" only: problems under 200k elements run on CPU.
    # Explicit CUDA/TORCH device selection must never silently fall back.
    if self.device == Device.AUTO and backend_name in ("cupy", "torch") and X is not None:
        n_rows, n_cols = X.shape
        if n_rows * n_cols < 200_000:
            backend_name = "numpy"

    backend_name = self._auto_backend_override(backend_name, X)
    self._selected_solver = self._select_solver(self._loss, backend_name=backend_name, X=X)
    self._selected_backend_name = backend_name

    # Penalties requiring initialization (e.g. adaptive lasso) get their
    # weights from an initial fit.
    if self._penalty.requires_init:
        init_coef = self._fit_initial(X, y, backend_name=backend_name)
        self._penalty.set_weights(init_coef)

    # Squared-error SCAD/MCP uses a dedicated FISTA+LLA continuation path,
    # matching R ncvreg's algorithm for Gaussian regression.
    # GLM + SCAD/MCP must NOT use IRLS-CD: the non-convex penalty makes
    # features flip on/off between IRLS iterations, so those fits take the
    # general path (_fit_lla -> FISTA with proximal operator).
    pen_name = str(getattr(self._penalty, 'name', '')).lower()
    loss_name = str(getattr(self._loss, 'name', '')).lower()
    is_glm_loss = loss_name not in ("squared_error", "")
    if pen_name in ("scad", "mcp") and self._lla_enabled and not is_glm_loss:
        self._fit_squared_error_nonconvex_lla(X, y, sample_weight, backend_name)
    else:
        X_arr = self._to_array(X, backend=backend_name)
        y_arr = self._to_array(y, backend=backend_name)
        if backend_name == "torch":
            self._fit_torch(X_arr, y_arr, sample_weight)
        elif backend_name == "cupy":
            self._fit_gpu(X_arr, y_arr, sample_weight)
        else:
            self._fit_cpu(X_arr, y_arr, sample_weight)
        self._compute_post_fit_gaussian_inference(X, y, sample_weight=sample_weight)
    self._fitted = True

    # Clean up the CV cache on every path, unless a caller is intentionally
    # reusing one across repeated fits, as PenalizedGLM_CV does within a fold.
    if hasattr(self, '_cv_cache') and not getattr(self, '_preserve_cv_cache', False):
        del self._cv_cache
    return self


def _fit_squared_error_nonconvex_lla(self, X, y, sample_weight, backend_name):
    """Fit squared-error SCAD/MCP with the fused FISTA+LLA continuation path.

    Builds a 20-step decreasing ``alpha`` path from
    ``max(lambda_max, 1.1 * alpha)`` to the target ``alpha``. The path is
    geometric when both ends are positive and finite, and linear otherwise.
    It is passed to ``fista_lla_path``.

    The method sets ``_nobs``, ``coef_``, ``intercept_``, ``n_iter_``,
    ``_params`` and ``_df_resid``, and runs the post-fit Gaussian inference.
    It then calls the CuPy/Torch memory cleanup hook for the backend in use.
    """
    from statgpu.solvers import fista_lla_path

    self._nobs = X.shape[0]
    X_arr = self._to_array(X, backend=backend_name)
    y_arr = self._to_array(y, backend=backend_name)

    # lambda_max on unit-RMS columns and centered y. It is computed with
    # NumPy because this is a one-time cost.
    X_np = _to_numpy(X_arr)
    y_np = _to_numpy(y_arr)
    n = X_np.shape[0]
    col_norms = np.maximum(np.sqrt(np.sum(X_np ** 2, axis=0)), 1e-20)
    X_std = X_np * (np.sqrt(n) / col_norms)
    y_centered = y_np - np.mean(y_np)
    lam_max = float(np.max(np.abs(X_std.T @ y_centered / n)))

    target_alpha = float(self._penalty.alpha)
    n_steps = 20
    alpha_start = max(lam_max, target_alpha * 1.1)
    if (not np.isfinite(alpha_start)) or alpha_start <= 0.0 or target_alpha <= 0.0:
        # geomspace needs strictly positive finite endpoints.
        alpha_path = np.linspace(max(lam_max, 0.0), target_alpha, n_steps)
    else:
        alpha_path = np.geomspace(alpha_start, target_alpha, n_steps)

    max_lla_per_step = max(6, getattr(self, '_max_lla_iters', 50) // n_steps)
    # Intermediate continuation steps get a reduced iteration budget. Only the
    # final (target alpha) step gets the full max_iter.
    full_budget = self.max_iter
    iter_path = [max(100, full_budget // 10)] * (n_steps - 1) + [full_budget]

    coef, intercept, n_iter = fista_lla_path(
        self._loss, self._penalty,
        X_arr, y_arr,
        alpha_path=alpha_path,
        max_lla_per_step=max_lla_per_step,
        lla_tol=getattr(self, '_lla_tol', 1e-6),
        max_iter=iter_path,
        tol=self.tol,
        fit_intercept=self._effective_intercept,
        sample_weight=sample_weight,
    )
    self.coef_ = coef
    self.intercept_ = intercept
    self.n_iter_ = n_iter
    if self._effective_intercept:
        self._params = np.concatenate([[self.intercept_], np.asarray(self.coef_)])
    else:
        self._params = np.asarray(self.coef_).copy()
    self._df_resid = X.shape[0] - (X.shape[1] + (1 if self._effective_intercept else 0))
    self._compute_post_fit_gaussian_inference(X, y, sample_weight=sample_weight)
    if backend_name == "cupy":
        self._cleanup_cuda_memory()
    elif backend_name == "torch":
        self._cleanup_torch_memory()
377
+
378
def _select_solver(self, loss, backend_name=None, X=None):
    """Return the solver name to use for this fit.

    An explicitly requested ``self.solver`` (anything but ``'auto'``) is
    returned untouched. For ``'auto'``, the decision is delegated to
    ``_preferred_penalized_glm_solver`` with these inputs:

    * the loss and penalty names, falling back to ``self.loss`` and
      ``self.penalty``;
    * the penalty's ``l1_ratio``, falling back to ``self.l1_ratio``;
    * the backend name;
    * ``n_samples * n_features`` when ``X`` is given.

    ``cv_mode`` is always False here.
    """
    if self.solver != "auto":
        return self.solver

    size = None
    if X is not None:
        n_rows, n_cols = X.shape[0], X.shape[1]
        size = int(n_rows) * int(n_cols)

    resolved_loss = getattr(loss, "name", self.loss)
    resolved_penalty = getattr(self._penalty, "name", self.penalty)
    resolved_ratio = getattr(self._penalty, "l1_ratio", self.l1_ratio)
    return _preferred_penalized_glm_solver(
        resolved_loss,
        resolved_penalty,
        backend_name=backend_name,
        l1_ratio=resolved_ratio,
        cv_mode=False,
        problem_size=size,
    )
390
+
391
@staticmethod
def _torch_cuda_available():
    """Report whether PyTorch is importable and reports CUDA as available.

    This is a deliberate best-effort probe. Any exception (missing package,
    broken driver, ...) counts as "not available".
    """
    try:
        import torch
        has_cuda = torch.cuda.is_available()
    except Exception:
        has_cuda = False
    return has_cuda
398
+
399
@staticmethod
def _cupy_available():
    """Report whether CuPy is importable and the CUDA runtime lists a device.

    This is a deliberate best-effort probe. Import and CUDA runtime errors
    alike are reported as False.
    """
    try:
        import cupy
        device_count = cupy.cuda.runtime.getDeviceCount()
    except Exception:
        return False
    return device_count > 0
406
+
407
# Benchmark-backed backend routing tables, read by ``_auto_backend_override``
# only for device='auto', solver='auto' and n_samples * n_features >= 1M.
# Each rule is (loss, penalties, target_backend, reason_template). Within a
# table, the first rule whose loss matches and whose penalty is listed wins.

# CPU rules apply whatever backend was picked, and they route to
# ``target_backend``. Their reason templates may use ``{penalty}``.
_AUTO_BACKEND_CPU_OVERRIDES = [
    (
        "squared_error", ("l2",),
        "numpy", "large squared-error exact solve is faster on CPU",
    ),
    (
        "squared_error", ("l1", "elasticnet", "en"),
        "numpy", "large squared-error l1/elasticnet is faster on CPU",
    ),
    (
        "negative_binomial", ("l1", "elasticnet", "en"),
        "numpy", "large negative-binomial l1/elasticnet is faster on CPU",
    ),
    (
        "logistic", ("l1", "elasticnet", "en"),
        "numpy", "large logistic {penalty} is faster on CPU",
    ),
    (
        "gamma", ("l2",),
        "numpy", "large gamma l2/newton is faster on CPU",
    ),
    (
        "tweedie", ("l1", "elasticnet", "en"),
        "numpy", "large tweedie {penalty} is faster on CPU",
    ),
]

# CuPy rules apply only when CuPy was selected. They move the fit to Torch
# when a CUDA-enabled Torch is available, and to NumPy otherwise. Their
# templates may use ``{penalty}`` and ``{target}``.
_AUTO_BACKEND_CUPY_OVERRIDES = [
    (
        "negative_binomial", ("l2",),
        "torch", "large negative-binomial l2 is faster on {target} than cupy",
    ),
    (
        "logistic", ("l1", "elasticnet", "en"),
        "torch", "large logistic {penalty} is faster on {target} than cupy",
    ),
    (
        "poisson", ("l1", "elasticnet", "en"),
        "torch", "large poisson {penalty} is faster on {target} than cupy",
    ),
]
424
+
425
def _auto_backend_override(self, backend_name, X):
    """Benchmark-backed backend routing, applied only for device='auto'.

    ``backend_name`` is returned unchanged unless all of these hold:
    ``self.device`` is ``Device.AUTO``, ``self.solver`` is ``'auto'``, ``X``
    is given, and ``n_samples * n_features >= 1_000_000``. In that regime:

    1. The first matching ``_AUTO_BACKEND_CPU_OVERRIDES`` rule returns its
       target backend, regardless of the current backend.
    2. Otherwise, if CuPy was selected, the first matching
       ``_AUTO_BACKEND_CUPY_OVERRIDES`` rule moves the fit to ``'torch'``
       when a CUDA-enabled Torch is available, and to ``'numpy'`` otherwise.

    Parameters
    ----------
    backend_name : str
        Backend chosen so far ('numpy', 'cupy', 'torch').
    X : array-like with ``.shape``, or None
        Design matrix; only its shape is read.

    Returns
    -------
    str
        Backend name to use. ``self._auto_backend_reason`` holds a
        human-readable explanation when an override fires, and None
        otherwise.
    """
    self._auto_backend_reason = None
    if self.device != Device.AUTO or self.solver != "auto" or X is None:
        return backend_name

    n_samples, n_features = X.shape
    if int(n_samples) * int(n_features) < 1_000_000:
        return backend_name

    loss_name = str(getattr(self._loss, "name", self.loss)).lower()
    penalty_name = str(getattr(self._penalty, "name", self.penalty)).lower()

    # CPU overrides: route to the rule's target whatever the current backend.
    for loss, penalties, target, reason_tpl in self._AUTO_BACKEND_CPU_OVERRIDES:
        if loss_name == loss and penalty_name in penalties:
            self._auto_backend_reason = reason_tpl.format(penalty=penalty_name)
            return target

    if backend_name != "cupy":
        return backend_name

    # CuPy -> Torch overrides: prefer Torch when available, else CPU.
    for loss, penalties, _target, reason_tpl in self._AUTO_BACKEND_CUPY_OVERRIDES:
        if loss_name == loss and penalty_name in penalties:
            # Probe Torch lazily. The probe imports torch, which is only worth
            # doing when a CuPy -> Torch override actually fires.
            if self._torch_cuda_available():
                self._auto_backend_reason = reason_tpl.format(
                    penalty=penalty_name, target="torch")
                return "torch"
            self._auto_backend_reason = reason_tpl.format(
                penalty=penalty_name, target="CPU")
            return "numpy"

    return backend_name
459
+
460
def _fit_initial(self, X, y, backend_name="numpy"):
    """Compute initial coefficients for penalties that need initialization.

    The branch is chosen from the loss name (``self.loss``), the penalty's
    convexity and ``requires_init`` flag, and its ``init_method``
    ('auto' by default):

    1. Squared-error loss, a penalty with ``requires_init`` false, and
       ``init_method == 'ols'`` (or ``'auto'`` with ``n_samples >
       n_features``): least squares via ``np.linalg.lstsq``. This does NOT
       raise when ``p > n``; lstsq returns its minimum-norm solution.
    2. GLM loss (anything other than 'squared_error') with a non-convex
       penalty (SCAD, MCP): a dense, lightly L2-penalized (``alpha=0.001``)
       GLM fit via ``fista_solver`` (``max_iter=500``, ``tol=1e-4``). With
       lla_weights = P'(|coef|), a dense start lets the LLA continuation path
       push small coefficients through the region where SCAD and MCP differ,
       as R's ncvreg does.
    3. Any other penalty with ``requires_init`` (e.g. adaptive_l1):
       ``_irls_ridge_init`` with ``alpha=0.01``. Weights are
       ``1/(|init_coef|+eps)^nu``, so the seed must be dense; exact zeros
       would leave those features permanently frozen.
    4. Otherwise: ``Ridge(alpha=0.1)`` coefficients, using this model's
       intercept setting and device.

    OLS is reserved for squared error. For GLM losses it can produce extreme
    coefficients whose Lipschitz constant is so large that the inner FISTA
    solver takes zero-length steps.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
        Design matrix.
    y : array of shape (n_samples,)
        Target vector.
    backend_name : str
        Backend used by paths 2 and 3 ('numpy', 'torch', 'cupy'). The GPU
        backends are requested with device 'cuda'. Default 'numpy'.

    Returns
    -------
    init_coef : array
        Initial coefficients. Path 3 returns NumPy float64. Path 2 returns
        whatever ``fista_solver`` produces for the given backend.
    """
    n_samples, n_features = X.shape
    init_method = getattr(self._penalty, "init_method", "auto")
    is_glm = getattr(self, 'loss', 'squared_error') != "squared_error"
    is_nonconvex = not getattr(self._penalty, "is_convex", True)

    def _as_float64_on_backend(X_in, y_in):
        # Materialize X/y as float64 on the backend requested for paths 2/3.
        if backend_name in ("torch", "cupy"):
            backend = get_backend(backend=backend_name, device='cuda')
            return (
                backend.asarray(X_in, dtype=backend.float64),
                backend.asarray(y_in, dtype=backend.float64),
            )
        return (
            np.asarray(_to_numpy(X_in), dtype=np.float64),
            np.asarray(_to_numpy(y_in), dtype=np.float64),
        )

    # Path 1: OLS (squared error only).
    if not is_glm and not self._penalty.requires_init and (
        init_method == "ols" or (init_method == "auto" and n_samples > n_features)
    ):
        ols_coef, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
        return ols_coef

    # Path 2: dense L2-penalized GLM seed for non-convex penalties.
    if is_glm and is_nonconvex:
        from statgpu.penalties import get_penalty
        from statgpu.solvers import fista_solver

        l2_pen = get_penalty("l2", alpha=0.001)
        loss_obj = self._resolve_loss()
        X_b, y_b = _as_float64_on_backend(X, y)
        init_coef, _ = fista_solver(
            loss_obj, l2_pen, X_b, y_b,
            max_iter=500, tol=1e-4,
        )
        return init_coef

    # Path 3: ridge seed for adaptive weights. ``_irls_ridge_init`` solves
    # squared error in closed form and other losses with L2-penalized FISTA.
    if self._penalty.requires_init:
        loss_name = getattr(self, 'loss', 'squared_error')
        X_b, y_b = _as_float64_on_backend(X, y)
        return _irls_ridge_init(
            X_b, y_b,
            loss_name=loss_name,
            alpha=0.01,
            max_iter=100,
            tol=1e-4,
            loss_kwargs=getattr(self, "loss_kwargs", None),
        )

    # Path 4: generic ridge fallback.
    from statgpu.linear_model.wrappers._ridge import Ridge

    init_model = Ridge(
        alpha=0.1,
        fit_intercept=self._effective_intercept,
        device=self.device,
    )
    init_model.fit(X, y)
    return init_model.coef_
561
+
562
def _fit_cpu(self, X, y, sample_weight=None):
    """Fit the penalized linear model on CPU with NumPy.

    Routing:

    * Non-``squared_error`` losses, and the ``'irls'``, ``'newton'``,
      ``'lbfgs'`` and ``'admm'`` solvers, are delegated to
      ``_fit_irls_backend`` or ``_fit_loss_backend``.
    * ``squared_error`` with SCAD, MCP, adaptive L1 or group lasso is
      delegated to ``_fit_loss_backend``, so CPU and GPU produce identical
      results.
    * Everything else is solved here on the (optionally sample-weighted and
      centered) Gram matrix ``X'X``. The data term is
      ``||y - Xw||^2 / (2 n)`` plus the ``alpha``-scaled penalty.
      - ``solver='exact'`` (L2 only) uses ``_solve_exact_numpy``.
      - ``'fista'`` / ``'fista_bb'`` run accelerated proximal gradient with
        ``penalty.proximal``. Both use the fixed ``1/L`` step.
      - Any other solver runs cyclic coordinate descent, which supports
        L1 and elastic net.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
    y : array-like of shape (n_samples,); an ``(n_samples, 1)`` column
        is flattened before forming ``X'y``.
    sample_weight : array-like of shape (n_samples,), optional
        Applied by scaling rows of ``X`` and ``y`` by ``sqrt(sample_weight)``.

    Sets ``coef_``, ``intercept_``, ``_params``, ``n_iter_``, ``_df_resid``
    and ``_nobs``.

    Raises
    ------
    ValueError
        ``solver='exact'`` with a non-L2 penalty.
    NotImplementedError
        Coordinate descent requested for a penalty other than L1 or
        elastic net. This is raised before iterating.
    """
    X = np.asarray(X)
    y = np.asarray(y)

    n_samples, n_features = X.shape
    self._nobs = n_samples

    # Route to the loss-aware solvers for non-squared_error losses.
    solver_name = self._selected_solver or self._select_solver(
        self._loss, backend_name="numpy"
    )
    if self.loss != "squared_error" or solver_name in ("irls", "newton", "lbfgs", "admm"):
        if solver_name == "irls":
            self._fit_irls_backend(X, y, sample_weight, "numpy")
        else:
            self._fit_loss_backend(X, y, sample_weight, solver_name, "numpy")
        return

    pen = self._penalty

    # squared_error + SCAD/MCP/adaptive_l1/group_lasso goes through
    # _fit_loss_backend so CPU and GPU paths produce identical results.
    # Because of this early return, only the remaining penalties (L1,
    # elastic net, L2, ...) ever reach the solvers below.
    _loss_backend_penalties = ("scad", "mcp", "adaptive_l1", "adaptive_lasso", "group_lasso")
    if getattr(pen, "name", "") in _loss_backend_penalties:
        self._fit_loss_backend(X, y, sample_weight, solver_name, "numpy")
        return

    # Original squared-error path (backward compatible).
    if sample_weight is not None:
        sqrt_sw = np.sqrt(np.asarray(sample_weight))
        X = X * sqrt_sw[:, np.newaxis]
        y = y * sqrt_sw

    if self._effective_intercept:
        X_mean = np.mean(X, axis=0)
        y_mean = np.mean(y)
        X_centered = X - X_mean
        y_centered = y - y_mean
    else:
        X_centered = X
        y_mean = 0.0
        y_centered = y
    y_centered = y_centered.reshape(-1)

    # Gram quantities (reused from the cross-validation cache when present).
    _cv = getattr(self, "_cv_cache", None)
    if _cv is not None and "XtX" in _cv:
        XtX = _cv["XtX"]
        Xty = _cv["Xty"]
    else:
        XtX = X_centered.T @ X_centered
        Xty = X_centered.T @ y_centered

    if solver_name == "exact":
        if pen.name != "l2":
            raise ValueError("solver='exact' is only supported for L2/Ridge penalty.")
        coef = self._solve_exact_numpy(XtX, Xty, n_samples)
        self.n_iter_ = 1
    else:
        # Lipschitz constant of the smooth part: L = lambda_max(X'X) / n.
        if self.lipschitz_L is not None:
            L = float(self.lipschitz_L)
        else:
            from statgpu.backends._array_ops import _max_eigval_power
            L = _max_eigval_power(XtX) / n_samples

        if getattr(self, "_init_coef", None) is not None:
            coef = np.asarray(self._init_coef, dtype=np.float64).copy()
        else:
            coef = np.zeros(n_features)
        self.n_iter_ = 0

        if L <= 0:
            # X'X is zero: no feature carries signal, so every coefficient is 0.
            coef = np.zeros(n_features)
        elif solver_name in ("fista_bb", "fista"):
            # FISTA with X'X precomputation. BB steps (fista_bb) give no
            # benefit for quadratic losses (BB1 = BB2 = 1/R_H(dw)), so both
            # variants use the fixed Lipschitz step.
            step = 1.0 / L
            y_k = coef.copy()
            t_k = 1.0
            for iteration in range(self.max_iter):
                coef_old = coef.copy()
                grad_at_y = (XtX @ y_k - Xty) / n_samples
                coef = pen.proximal(y_k - step * grad_at_y, step, backend="numpy")

                # Scheduled momentum restart every 50 iterations.
                if iteration > 0 and iteration % 50 == 0:
                    t_k = 1.0
                y_k, t_k = _nesterov_update(coef, coef_old, t_k)

                self.n_iter_ = iteration + 1
                if np.sum(np.abs(coef - coef_old)) < self.tol:
                    break
        else:
            # Cyclic coordinate descent. Each coordinate update is
            #   w_j = S(rho_j, cd_thresh) / (X_j'X_j + cd_ridge)
            # which matches sklearn / R glmnet for L1 (cd_ridge = 0) and
            # for elastic net.
            if pen.name == "l1":
                cd_thresh = self.alpha * n_samples
                cd_ridge = 0.0
            elif pen.name in ("elasticnet", "en"):
                # Prefer the penalty's own l1_ratio, as the GPU path does.
                l1_ratio = pen.l1_ratio if hasattr(pen, "l1_ratio") else self.l1_ratio
                cd_thresh = self.alpha * l1_ratio * n_samples
                cd_ridge = self.alpha * (1 - l1_ratio) * n_samples
            else:
                raise NotImplementedError(
                    f"Coordinate descent not implemented for "
                    f"penalty '{pen.name}'. Use solver='fista'."
                )

            X_sq_norms = np.diag(XtX)
            for iteration in range(self.max_iter):
                coef_old = coef.copy()
                for j in range(n_features):
                    if X_sq_norms[j] > 1e-10:
                        # Partial correlation of feature j with the residual
                        # that excludes feature j's own contribution.
                        rho_j = Xty[j] - np.dot(XtX[j, :], coef) + XtX[j, j] * coef[j]
                        shrunk = np.sign(rho_j) * np.maximum(np.abs(rho_j) - cd_thresh, 0)
                        coef[j] = shrunk / (X_sq_norms[j] + cd_ridge)
                    else:
                        # Constant (all-zero after centering) column.
                        coef[j] = 0.0

                self.n_iter_ = iteration + 1
                if np.sum(np.abs(coef - coef_old)) < self.tol:
                    break

    # Recover the intercept from the centering and store results.
    self.coef_ = coef
    if self._effective_intercept:
        self.intercept_ = float(y_mean - X_mean @ self.coef_)
        self._params = np.concatenate([[self.intercept_], self.coef_])
    else:
        self.intercept_ = 0.0
        self._params = self.coef_.copy()

    self._df_resid = n_samples - (n_features + (1 if self._effective_intercept else 0))
823
+
824
def _fit_gpu(self, X, y, sample_weight=None):
    """Fit on a CUDA device through CuPy.

    Thin entry point: all of the work is done by ``_fit_gpu_backend`` with
    ``backend_name="cupy"``.
    """
    backend = "cupy"
    self._fit_gpu_backend(X, y, sample_weight=sample_weight, backend_name=backend)
827
+
828
def _fit_torch(self, X, y, sample_weight=None):
    """Fit on a GPU through PyTorch.

    Thin entry point: all of the work is done by ``_fit_gpu_backend`` with
    ``backend_name="torch"``.
    """
    backend = "torch"
    self._fit_gpu_backend(X, y, sample_weight=sample_weight, backend_name=backend)
831
+
832
+ # ------------------------------------------------------------------
833
+ # Unified GPU backend (shared implementation behind _fit_gpu and _fit_torch)
834
+ # ------------------------------------------------------------------
835
+
836
@staticmethod
def _soft_threshold_gpu(w, thresh, xp):
    """Element-wise soft-thresholding ``sign(w) * max(|w| - thresh, 0)``.

    ``xp`` is the array namespace. When ``xp.__name__`` is ``"torch"`` the
    clamp is done with ``torch.relu``; any other namespace uses
    ``xp.maximum``.
    """
    if xp.__name__ != "torch":
        magnitude = xp.maximum(xp.abs(w) - thresh, 0.0)
        return xp.sign(w) * magnitude
    import torch
    magnitude = torch.relu(torch.abs(w) - thresh)
    return torch.sign(w) * magnitude
843
+
844
def _fit_gpu_backend(self, X, y, sample_weight=None, backend_name="cupy"):
    """Shared GPU fit for the CuPy and Torch backends.

    Dispatch order:

    * ``solver='exact'`` (L2 penalty only): closed-form ridge solve on the
      centered Gram matrix via ``_solve_exact_cupy`` / ``_solve_exact_torch``.
      When ``compute_inference`` is set, this is followed by the matching
      ``_precompute_exact_l2_inference_*`` call.
    * ``'irls'`` goes to ``_fit_irls_backend``; ``'newton'`` and ``'lbfgs'``
      go to ``_fit_loss_backend``.
    * Non-``squared_error`` losses, ``'admm'``, and penalties other than L1 /
      elastic net go to ``_fit_loss_backend``.
    * Otherwise (squared error with L1 / elastic net) an inline FISTA loop
      runs on ``X'X``.
      - For ``'fista'`` / ``'fista_bb'`` the element-wise update is fused
        into one kernel (``torch.compile`` or ``cupy.fuse``). If kernel
        construction or its first invocation fails, it falls back to eager
        ops.
      - Any other solver reaching this point (e.g. ``'auto'``) uses a
        generic proximal-gradient loop via ``penalty.proximal``.

    Sets ``coef_``, ``intercept_``, ``_params``, ``n_iter_``, ``_df_resid``
    and ``_nobs``. On the exact and inline FISTA paths, device memory is
    released before returning.

    Raises
    ------
    ValueError
        Unsupported ``solver`` for the GPU backends, or ``solver='exact'``
        with a non-L2 penalty.
    """
    from statgpu.backends._utils import _get_xp, xp_asarray, xp_zeros, xp_copy, xp_ones
    from statgpu.backends import _to_numpy
    from statgpu.backends._array_ops import _abs_sum_dev

    xp = _get_xp(backend_name)
    is_torch = (backend_name == "torch")
    exact_suffix = "torch" if is_torch else "cupy"

    solver_name = self._selected_solver or self._select_solver(
        self._loss, backend_name=backend_name
    )
    _backend_label = "Torch" if is_torch else "CuPy"
    if solver_name not in ("fista", "fista_bb", "admm", "auto", "exact", "irls", "newton", "lbfgs"):
        raise ValueError(
            f"{_backend_label} backend supports solver='fista', 'fista_bb', 'admm', "
            f"'exact', 'irls', 'newton', and 'lbfgs', got '{solver_name}'."
        )

    n_samples, n_features = X.shape
    self._nobs = n_samples

    def _prepare(X, y, sample_weight):
        # Cast to float64 on the target device, fold sqrt(sample_weight)
        # into the rows, center when fitting an intercept, and form X'X / X'y
        # (or reuse them from the cross-validation cache).
        X = xp_asarray(X, dtype=np.float64, xp=xp, ref_arr=X)
        y = xp_asarray(y, dtype=np.float64, xp=xp, ref_arr=y)
        if is_torch:
            import torch
            if X.dtype != torch.float64:
                X = X.to(torch.float64)
        if sample_weight is not None:
            sw = xp_asarray(sample_weight, dtype=X.dtype, xp=xp, ref_arr=X)
            sqrt_sw = xp.sqrt(sw)
            X = X * sqrt_sw[:, None]
            y = y * sqrt_sw
        if self._effective_intercept:
            X_mean = xp.mean(X, axis=0)
            y_mean = xp.mean(y)
            X_centered = X - X_mean
            y_centered = y - y_mean
        else:
            X_mean = None
            y_mean = xp_zeros((), X.dtype, xp, ref_arr=X) if is_torch else xp.array(0.0, dtype=X.dtype)
            X_centered = X
            y_centered = y
        # Flatten so X'y is a length-p vector even for an (n, 1) response;
        # a (p, 1) X'y would silently broadcast the iterate to (p, p).
        y_centered = y_centered.reshape(-1)
        _cv = getattr(self, '_cv_cache', None)
        if _cv is not None and 'XtX' in _cv:
            XtX = _cv['XtX']
            Xty = _cv['Xty']
        else:
            XtX = X_centered.T @ X_centered
            Xty = X_centered.T @ y_centered
        return X, y, X_mean, y_mean, XtX, Xty

    def _store(coef, X_mean, y_mean):
        # Transfer to CPU and recover the intercept from the centering.
        coef_np = _to_numpy(coef)
        if self._effective_intercept:
            self.intercept_ = float(_to_numpy(y_mean) - _to_numpy(X_mean) @ coef_np)
            self.coef_ = coef_np
            self._params = np.concatenate([[self.intercept_], self.coef_])
        else:
            self.intercept_ = 0.0
            self.coef_ = coef_np
            self._params = coef_np.copy()
        self._df_resid = n_samples - (n_features + (1 if self._effective_intercept else 0))

    def _release():
        if is_torch:
            self._cleanup_torch_memory()
        else:
            self._cleanup_cuda_memory()

    # --- Exact solver (closed-form Ridge) ---
    if solver_name == "exact":
        if self._penalty.name != "l2":
            raise ValueError("solver='exact' is only supported for L2/Ridge penalty.")
        X, y, X_mean, y_mean, XtX, Xty = _prepare(X, y, sample_weight)
        solve_fn = getattr(self, f'_solve_exact_{exact_suffix}')
        coef = solve_fn(XtX, Xty, n_samples)
        self.n_iter_ = 1
        if self.compute_inference:
            infer_fn = getattr(self, f'_precompute_exact_l2_inference_{exact_suffix}')
            if self._effective_intercept:
                intercept_gpu = (y_mean.reshape(1) - X_mean.reshape(1, -1) @ coef.reshape(-1, 1)).reshape(-1)
                coef_full_gpu = xp.concatenate([intercept_gpu, coef.reshape(-1)])
                infer_fn(X, y, XtX, X_mean, coef_full_gpu.reshape(-1), n_samples)
            else:
                infer_fn(X, y, XtX, None, coef.reshape(-1), n_samples)
        _store(coef, X_mean, y_mean)
        _release()
        return

    # Route IRLS/newton/lbfgs through their dedicated backends.
    if solver_name in ("irls", "newton", "lbfgs"):
        if solver_name == "irls":
            self._fit_irls_backend(X, y, sample_weight, backend_name)
        else:
            self._fit_loss_backend(X, y, sample_weight, solver_name, backend_name)
        return

    # Route non-L1 and non-squared-error through the generic loss backend.
    if self.loss != "squared_error" or solver_name == "admm" or self._penalty.name not in ("l1", "elasticnet", "en"):
        self._fit_loss_backend(X, y, sample_weight, solver_name, backend_name)
        return

    # --- Inline FISTA fast-path for L1 / elastic net + squared_error ---
    X, y, X_mean, y_mean, XtX, Xty = _prepare(X, y, sample_weight)

    # Lipschitz constant: L = lambda_max(X'X) / n.
    if self.lipschitz_L is not None:
        L = float(self.lipschitz_L)
    elif n_features < 1000:
        L = float(xp.linalg.eigvalsh(XtX)[-1]) / n_samples
    else:
        # Power iteration: cheaper than a full eigendecomposition for large p.
        v = xp_ones(n_features, X.dtype, xp, ref_arr=X)
        v = v / xp.linalg.norm(v)
        for _ in range(50):
            v_new = XtX @ v
            v_norm = xp.linalg.norm(v_new)
            if v_norm < 1e-15:
                break
            v = v_new / v_norm
        L = float(_to_numpy(v @ (XtX @ v))) / n_samples

    def _starting_coef():
        if getattr(self, '_init_coef', None) is not None:
            return xp_asarray(self._init_coef, dtype=X.dtype, xp=xp, ref_arr=X)
        return xp_zeros(n_features, X.dtype, xp, ref_arr=X)

    if L <= 0:
        coef = xp_zeros(n_features, X.dtype, xp, ref_arr=X)
        self.n_iter_ = 0
    elif solver_name in ("fista_bb", "fista"):
        step = 1.0 / L
        step_over_n = step / n_samples
        step_over_n_Xty = step_over_n * Xty
        if self._penalty.name in ("elasticnet", "en"):
            thresh = self.alpha * self._penalty.l1_ratio * step
            l2_scale = 1.0 + self.alpha * (1.0 - self._penalty.l1_ratio) * step
        else:
            thresh = self.alpha * step
            l2_scale = 1.0
        if abs(l2_scale - 1.0) <= 1e-12:
            # Negligible ridge part: treat as pure L1. Dividing by exactly 1.0
            # in the shared kernel is then a bit-exact no-op.
            l2_scale = 1.0

        coef = _starting_coef()
        y_k = xp_copy(coef)
        t_k = 1.0
        beta = 0.0
        _st_fn = self._soft_threshold_gpu

        def _eager_step(y_k, xtx_y, coef_old, beta):
            # Proximal step followed by the momentum extrapolation.
            w_tilde = y_k - step_over_n * xtx_y + step_over_n_Xty
            new_coef = _st_fn(w_tilde, thresh, xp)
            if l2_scale != 1.0:
                new_coef = new_coef / l2_scale
            return new_coef, new_coef + beta * (new_coef - coef_old)

        # Build one fused element-wise kernel (backend-specific JIT).
        # Signature: (y_k, X'X y_k, step/n * X'y, step/n, thresh, l2_scale,
        #             coef_old, beta) -> (coef, y_k).
        fused_step = None
        if is_torch:
            import torch

            def _fista_elementwise(_y_k, _xtx_y, _step_over_n_Xty, _step_over_n,
                                   _thresh, _l2_scale, _coef_old, _beta):
                w = _y_k - _step_over_n * _xtx_y + _step_over_n_Xty
                c = _st_fn(w, _thresh, xp) / _l2_scale
                return c, c + _beta * (c - _coef_old)

            # NOTE: torch.compile is lazy -- real compilation (and any failure,
            # e.g. a missing Triton toolchain) happens on the first call, so
            # that call is guarded in the loop below as well.
            # NOTE(review): beta changes every iteration as a Python float,
            # which may trigger dynamo recompiles -- confirm on target torch.
            try:
                fused_step = torch.compile(_fista_elementwise, mode='reduce-overhead')
            except Exception:
                fused_step = None
        else:
            import cupy as cp
            try:
                @cp.fuse()
                def _fista_elementwise(_y_k, _xtx_y, _step_over_n_Xty, _step_over_n,
                                       _thresh, _l2_scale, _coef_old, _beta):
                    w = _y_k - _step_over_n * _xtx_y + _step_over_n_Xty
                    c = (cp.sign(w) * cp.maximum(cp.abs(w) - _thresh, 0.0) / _l2_scale)
                    y = c + _beta * (c - _coef_old)
                    return c, y
                # Warm-up call forces kernel compilation inside the try.
                _dummy = cp.zeros(1, dtype=X.dtype)
                _fista_elementwise(_dummy, _dummy, _dummy, 0.0, 0.0, 1.0, _dummy, 0.0)
                fused_step = _fista_elementwise
            except Exception:
                fused_step = None

        self.n_iter_ = 0
        for iteration in range(self.max_iter):
            coef_old = xp_copy(coef)
            xtx_y = XtX @ y_k

            if fused_step is not None:
                try:
                    coef, y_k = fused_step(
                        y_k, xtx_y, step_over_n_Xty, step_over_n,
                        thresh, l2_scale, coef_old, beta,
                    )
                except Exception:
                    # Deliberate best-effort: the fused kernel is only an
                    # optimisation; permanently fall back to eager ops.
                    fused_step = None
            if fused_step is None:
                coef, y_k = _eager_step(y_k, xtx_y, coef_old, beta)

            # Scheduled momentum restart every 50 iterations.
            if iteration > 0 and iteration % 50 == 0:
                t_k = 1.0
            beta, t_k = _nesterov_momentum(t_k)

            self.n_iter_ = iteration + 1
            # Convergence is checked every 5th iteration to limit host syncs.
            if iteration % 5 == 4 and float(_to_numpy(_abs_sum_dev(coef - coef_old))) < self.tol:
                break
    else:
        # Generic proximal FISTA (e.g. solver='auto').
        step = 1.0 / L
        coef = _starting_coef()
        y_k = xp_copy(coef)
        t_k = 1.0

        self.n_iter_ = 0
        for iteration in range(self.max_iter):
            coef_old = xp_copy(coef)
            grad = (XtX @ y_k - Xty) / n_samples
            w_tilde = y_k - step * grad
            coef = self._penalty.proximal(w_tilde, step, backend=backend_name)

            if iteration > 0 and iteration % 50 == 0:
                t_k = 1.0

            y_k, t_k = _nesterov_update(coef, coef_old, t_k)

            self.n_iter_ = iteration + 1
            if iteration % 5 == 4 and float(_to_numpy(_abs_sum_dev(coef - coef_old))) < self.tol:
                break

    _store(coef, X_mean, y_mean)

    # Debiased inference on GPU (before cleanup).
    if self.compute_inference and "debiased" in str(getattr(self, "inference_method", "")).lower():
        penalty_name = str(getattr(self._penalty, "name", self.penalty)).lower()
        if penalty_name in ("l1", "elasticnet", "en"):
            infer_fn = getattr(self, f'_compute_inference_debiased_{"torch" if is_torch else "gpu"}')
            infer_fn(X, y, coef)

    _release()
1161
+
1162
def _ridge_alpha_for_exact(self) -> float:
    """L2 strength for the exact ridge normal equations.

    Uses the penalty object's ``alpha`` when it defines one, otherwise the
    estimator's ``alpha``; the result is always a Python float.
    """
    penalty = self._penalty
    default_alpha = self.alpha
    return float(getattr(penalty, "alpha", default_alpha))
1165
+
1166
def _solve_exact_numpy(self, XtX, Xty, n_samples):
    """Closed-form ridge solve ``(X'X + n * alpha * I) w = X'y`` with NumPy.

    ``XtX`` is the unnormalized Gram matrix, hence the ``n_samples`` factor
    on the ridge term. If ``np.linalg.solve`` reports a singular system, the
    pseudo-inverse is used instead.
    """
    ridge = float(n_samples) * self._ridge_alpha_for_exact()
    system = XtX + ridge * np.eye(XtX.shape[0], dtype=XtX.dtype)
    try:
        return np.linalg.solve(system, Xty)
    except np.linalg.LinAlgError:
        return np.linalg.pinv(system) @ Xty
1176
+
1177
def _solve_exact_cupy(self, XtX, Xty, n_samples):
    """Closed-form ridge solve ``(X'X + n * alpha * I) w = X'y`` with CuPy.

    Tries Cholesky + two triangular solves first (the ridge term makes the
    system positive definite when ``alpha > 0``), then ``cp.linalg.solve``,
    then the pseudo-inverse.

    CuPy's linear-algebra routines do not raise on numerical failure by
    default (``cupyx.errstate`` linalg mode is ``'ignore'``): a failed
    factorization silently yields NaN/inf. Each attempt is therefore also
    checked for finiteness before it is accepted.
    """
    import cupy as cp
    from cupyx.scipy.linalg import solve_triangular as cp_solve_triangular

    alpha = self._ridge_alpha_for_exact()
    p = XtX.shape[0]
    A = XtX + (float(n_samples) * alpha) * cp.eye(p, dtype=XtX.dtype)

    def _is_finite(arr):
        return bool(cp.isfinite(arr).all())

    # Deliberate fallback chain: each failed attempt moves on to the next,
    # more robust (and slower) method.
    try:
        chol = cp.linalg.cholesky(A)
        tmp = cp_solve_triangular(chol, Xty, lower=True)
        coef = cp_solve_triangular(chol.T, tmp, lower=False)
        if _is_finite(coef):
            return coef
    except _LINALG_ERRORS:
        pass
    try:
        coef = cp.linalg.solve(A, Xty)
        if _is_finite(coef):
            return coef
    except _LINALG_ERRORS:
        pass
    return cp.linalg.pinv(A) @ Xty
1195
+
1196
def _solve_exact_torch(self, XtX, Xty, n_samples):
    """Closed-form ridge solve ``(X'X + n * alpha * I) w = X'y`` with PyTorch.

    ``torch.linalg.solve`` is used directly rather than Cholesky +
    triangular solves, because the extra kernel launches cost more than they
    save for small matrices. A ``RuntimeError`` (which includes torch's
    linalg errors) triggers a pseudo-inverse fallback. The identity is
    created on ``XtX``'s device and dtype.
    """
    import torch

    ridge = float(n_samples) * self._ridge_alpha_for_exact()
    eye = torch.eye(XtX.shape[0], dtype=XtX.dtype, device=XtX.device)
    system = XtX + ridge * eye
    try:
        return torch.linalg.solve(system, Xty)
    except RuntimeError:
        return torch.linalg.pinv(system) @ Xty
1210
+
1211
def _block_cd_group_lasso(self, pen, X_work, y_arr, init):
    """Block coordinate descent for the group_lasso penalty (NumPy).

    Follows R grpreg's block CD. Each sweep visits every group and:

    1. forms the group's partial residual correlation;
    2. solves the unpenalized group subproblem ``(X'X/n)_gg w = rho_g``;
    3. applies block soft-thresholding with ``alpha * sqrt_pg[g]``.

    When ``self._effective_intercept`` is true, the last column of
    ``X_work`` holds the intercept coefficient, which is reset to the mean
    residual after each sweep.

    Group indices and ``sqrt_pg`` weights are read from ``self._penalty``
    (falling back to ``pen`` when ``self`` has no ``_penalty``). A singular
    group block leaves that group at zero. Iteration stops when the largest
    coefficient change drops below ``self.tol`` or after ``self.max_iter``
    sweeps.

    Returns
    -------
    (beta, intercept, n_iter) : feature coefficients, float intercept
        (0.0 without one), and the number of sweeps performed.

    Raises
    ------
    ValueError
        If the penalty has no groups configured.
    """
    import numpy as np

    n_rows, n_cols = X_work.shape
    has_intercept = self._effective_intercept
    n_coef = n_cols - 1 if has_intercept else n_cols
    lam = self.alpha

    group_source = getattr(self, '_penalty', pen)
    groups = getattr(group_source, '_group_indices', None)
    group_scale = getattr(group_source, '_sqrt_pg', None)
    if groups is None or group_scale is None:
        raise ValueError(
            "group_lasso penalty must have groups set. "
            "Pass groups=... in penalty_kwargs."
        )

    gram = X_work.T @ X_work / n_rows
    xty = (X_work.T @ y_arr.flatten()) / n_rows
    diag_blocks = [gram[np.ix_(idx, idx)] for idx in groups]

    if init is None:
        coef = np.zeros(n_cols, dtype=np.float64)
    else:
        coef = np.array(init, dtype=np.float64)

    sweeps = 0
    for sweep in range(self.max_iter):
        sweeps = sweep + 1
        previous = coef.copy()

        for g, idx in enumerate(groups):
            block = diag_blocks[g]
            partial = xty[idx] - gram[idx, :] @ coef + block @ coef[idx]
            try:
                unpenalized = np.linalg.solve(block, partial)
            except np.linalg.LinAlgError:
                unpenalized = np.zeros(len(idx))
            norm = np.linalg.norm(unpenalized)
            cutoff = lam * group_scale[g]
            if norm <= cutoff:
                coef[idx] = 0.0
            else:
                coef[idx] = unpenalized * (1.0 - cutoff / norm)

        if has_intercept:
            coef[n_cols - 1] = np.mean(y_arr - X_work[:, :n_coef] @ coef[:n_coef])

        if np.max(np.abs(coef - previous)) < self.tol:
            break

    if has_intercept:
        return coef[:n_coef], float(coef[n_coef]), sweeps
    return coef, 0.0, sweeps
1280
+
1281
def _block_cd_group_lasso_gpu(self, pen, X_work, y_arr, init, backend_name):
    """GPU-resident block coordinate descent for the group_lasso penalty.

    Same algorithm as ``_block_cd_group_lasso``, but every array stays on the
    ``backend_name`` device. Each sweep does per-group partial residual,
    unpenalized block solve, and block soft-thresholding, then refreshes the
    intercept.

    Numerical safeguards:

    * ``X_work``, ``y_arr`` and the warm start (including a device-resident
      one) are promoted to float64, to avoid NaN from float32 conditioning.
    * Each diagonal Gram block gets a ``1e-10`` ridge.
    * A group whose block solve fails or returns non-finite values is set
      to zero.

    When ``self._effective_intercept`` is true, the last column of
    ``X_work`` holds the intercept coefficient.

    Returns
    -------
    (beta, intercept, n_iter) : device-resident feature coefficients, float
        intercept (0.0 without one), and the number of sweeps performed.

    Raises
    ------
    ValueError
        If the penalty has no groups configured.
    """
    from statgpu.backends._array_ops import (
        _scalar_tensor,
        _xp_asarray,
        _xp_copy,
        _xp_eye,
        _xp_zeros,
    )
    from statgpu.backends._utils import _get_xp, xp_astype
    xp = _get_xp(backend_name)

    # Enforce float64 precision for numerical stability.
    X_work = xp_astype(X_work, xp.float64, xp)
    y_arr = xp_astype(y_arr, xp.float64, xp)

    n, pp = X_work.shape
    p = pp - 1 if self._effective_intercept else pp
    alpha = self.alpha

    _inner = getattr(self, '_penalty', pen)
    _g_indices = getattr(_inner, '_group_indices', None)
    _sqrt_pg_np = getattr(_inner, '_sqrt_pg', None)
    if _g_indices is None or _sqrt_pg_np is None:
        raise ValueError(
            "group_lasso penalty must have groups set. "
            "Pass groups=... in penalty_kwargs."
        )
    _n_groups = len(_g_indices)
    _sqrt_pg = [float(s) for s in _sqrt_pg_np]

    XtX = X_work.T @ X_work / n
    Xty = (X_work.T @ y_arr.flatten()) / n

    # Pre-compute XtX blocks with a tiny diagonal ridge for conditioning.
    _XtX_blocks = []
    _ridge = _scalar_tensor(1e-10, X_work)
    for g_idx in _g_indices:
        block = XtX[g_idx][:, g_idx]
        block = block + _ridge * _xp_eye(block.shape[0], block.dtype, block)
        _XtX_blocks.append(block)

    if init is None:
        coef = _xp_zeros(pp, X_work.dtype, X_work)
    elif isinstance(init, np.ndarray):
        coef = _xp_asarray(init, X_work.dtype, X_work)
    else:
        # Copy first so the caller's array is never mutated, then promote to
        # float64 so in-place group updates keep full precision.
        coef = xp_astype(_xp_copy(init), xp.float64, xp)

    iteration = -1  # ensure defined when max_iter=0
    for iteration in range(self.max_iter):
        coef_old = _xp_copy(coef)

        for g in range(_n_groups):
            g_idx = _g_indices[g]
            rho_g = Xty[g_idx] - XtX[g_idx, :] @ coef + _XtX_blocks[g] @ coef[g_idx]
            try:
                w_g = xp.linalg.solve(_XtX_blocks[g], rho_g)
                if xp.any(xp.isnan(w_g)) or xp.any(xp.isinf(w_g)):
                    w_g = _xp_zeros(len(g_idx), X_work.dtype, X_work)
            except Exception:
                # Deliberate best-effort: an unsolvable block zeroes the group.
                w_g = _xp_zeros(len(g_idx), X_work.dtype, X_work)
            norm_w = float(xp.linalg.norm(w_g))
            thresh_g = alpha * _sqrt_pg[g]
            if norm_w > thresh_g:
                coef[g_idx] = w_g * (1.0 - thresh_g / norm_w)
            else:
                coef[g_idx] = 0.0

        if self._effective_intercept:
            coef[pp - 1] = float(xp.mean(y_arr - X_work[:, :p] @ coef[:p]))

        _max_change = float(xp.max(xp.abs(coef - coef_old)))
        if _max_change < self.tol:
            break

    n_iter = iteration + 1

    if self._effective_intercept:
        beta = coef[:p]
        intercept = float(coef[p])
    else:
        beta = coef
        intercept = 0.0

    return beta, intercept, n_iter
1367
+
1368
def _fit_loss_backend(self, X, y, sample_weight, solver_name, backend_name):
    """Fit GLMLoss + Penalty without changing the selected backend.

    ``X`` and ``y`` are converted to float64 arrays on ``backend_name``. When
    an intercept is fitted, a column of ones is appended as the LAST column of
    the working design. The penalty is then wrapped via
    ``self._selective_penalty`` before being handed to the generic solvers.

    Routing (first match wins):

    * adaptive_l1 / adaptive_lasso        -> ``fista_solver``
    * scad / mcp + squared_error          -> ``fista_lla_path`` continuation
    * scad / mcp + GLM loss               -> ``fista_lla_path`` continuation;
      if ``self._cv_alpha_path`` is set, the whole path is requested and
      stored in ``self._cv_path_results``
    * group_mcp / group_scad + GLM loss   -> ``fista_lla_path`` with an
      ``AdaptiveGroupLassoPenalty`` inner penalty
    * group_lasso + squared_error         -> block coordinate descent
    * anything else                       -> dispatch on ``solver_name``

    Side effects: sets ``coef_``, ``intercept_``, ``n_iter_``, ``_params``
    (``[intercept, coef...]`` when an intercept is fitted) and ``_df_resid``.
    GPU memory pools are released afterwards (subject to
    ``gpu_memory_cleanup``), whether or not the fit succeeds.

    Parameters
    ----------
    X, y : array-like
        Design matrix (without intercept column) and response.
    sample_weight : array-like or None
        Forwarded to every solver except the group-lasso block CD.
    solver_name : str
        "auto", "fista", "fista_bb", "admm", "newton" or "lbfgs". Only
        consulted when no penalty-specific route applies.
    backend_name : str
        Backend key, e.g. "numpy", "cupy" or "torch".

    Raises
    ------
    ValueError
        If ``solver_name`` is unsupported and no penalty-specific route
        applies.
    """
    from statgpu.solvers import (
        admm_solver,
        fista_bb_solver,
        fista_lla_path,
        fista_solver,
        lbfgs_solver,
        newton_solver,
    )
    from statgpu.backends._array_ops import _xp_asarray
    from statgpu.backends._utils import _get_xp

    def _lambda_max(xp_mod, X_mat, y_vec, use_torch_clamp):
        # max_j |x_j^T (y - mean(y))| / n with every column rescaled to norm
        # sqrt(n). Used as the starting alpha of the continuation path.
        n_obs = X_mat.shape[0]
        col_norms = xp_mod.sqrt(xp_mod.sum(X_mat ** 2, axis=0))
        if use_torch_clamp:
            import torch

            col_norms = torch.clamp(col_norms, min=1e-20)
        else:
            col_norms = xp_mod.maximum(col_norms, 1e-20)
        X_scaled = X_mat * (float(n_obs) ** 0.5 / col_norms)
        y_centred = y_vec - xp_mod.mean(y_vec)
        return float(xp_mod.max(xp_mod.abs(X_scaled.T @ y_centred / n_obs)))

    def _default_schedule(lam_max, target_alpha, n_cont=20):
        # Geometric alpha path from max(lambda_max, 1.1 * target) down to the
        # target. Intermediate steps get a reduced iteration budget; the
        # final step gets the full ``max_iter``.
        alpha_path = np.geomspace(
            max(lam_max, target_alpha * 1.1), target_alpha, n_cont,
        )
        max_lla = max(6, getattr(self, '_max_lla_iters', 50) // n_cont)
        reduced = max(100, self.max_iter // 10)
        mi_path = [reduced] * (n_cont - 1) + [self.max_iter]
        return alpha_path, max_lla, mi_path

    def _lla_params(coef_out, intercept_out):
        # fista_lla_path returns host arrays, so keep the packed result on the
        # host in float64. Round-tripping through ``xp.asarray([intercept])``
        # would create a float32 tensor under torch and round the intercept.
        coef_host = np.asarray(_to_numpy(coef_out), dtype=np.float64).ravel()
        if self._effective_intercept:
            return np.concatenate([coef_host, [float(intercept_out)]])
        return coef_host

    try:
        # Convert to the target backend in float64 for numerical stability.
        _xp = _get_xp(backend_name)
        _ref = X if not isinstance(X, np.ndarray) else _xp.zeros(1, dtype=_xp.float64)
        X_arr = _xp_asarray(X, _xp.float64, _ref)
        y_arr = _xp_asarray(y, _xp.float64, X_arr)
        p = X_arr.shape[1]
        _loss_name = getattr(self._loss, 'name', '')

        init = None
        if self._effective_intercept:
            X_work = self._column_stack(
                [X_arr, self._ones(X_arr.shape[0], backend_name, X_arr)],
                backend_name,
            )
            pen = self._selective_penalty(p, backend_name)
            if self._init_coef is not None:
                init_intercept = float(getattr(self, '_init_intercept', 0.0) or 0.0)
                # _to_numpy so GPU-resident warm starts are accepted too.
                init_np = np.append(
                    np.asarray(_to_numpy(self._init_coef), dtype=np.float64).ravel(),
                    init_intercept,
                )
            else:
                # Warm-start the unpenalised intercept at link(mean(y)). This
                # prevents the intercept diverging toward -inf on zero-heavy
                # data.
                _y_mean = float(np.mean(_to_numpy(y_arr)))
                if _loss_name == "logistic":
                    _y_clip = np.clip(_y_mean, 1e-3, 1.0 - 1e-3)
                    _int_init = np.log(_y_clip / (1.0 - _y_clip))
                elif _loss_name in ("poisson", "gamma", "inverse_gaussian",
                                    "negative_binomial", "tweedie"):
                    _int_init = np.log(max(_y_mean, 1e-3))  # log link
                else:
                    _int_init = _y_mean  # identity link (squared_error)
                init_np = np.zeros(p + 1)
                init_np[-1] = _int_init
            init = _xp_asarray(init_np, X_arr.dtype, X_arr)
        else:
            X_work = X_arr
            pen = self._penalty
            if self._init_coef is not None:
                init_np = np.asarray(_to_numpy(self._init_coef), dtype=np.float64).ravel()
                init = _xp_asarray(init_np, X_arr.dtype, X_arr)

        # SelectivePenalty (intercept wrapper) has no name or hyper-parameters;
        # fall back to the original penalty for routing decisions.
        _pen_name = getattr(pen, 'name', '') or getattr(self._penalty, 'name', '')
        _is_glm_loss = _loss_name not in ("squared_error", "")
        _is_scad_mcp = _pen_name in ("scad", "mcp")
        _use_fista = _pen_name in ("adaptive_l1", "adaptive_lasso")
        _use_lla_squared = _is_scad_mcp and not _is_glm_loss
        _use_lla_fista = _is_scad_mcp and _is_glm_loss
        _use_lla_group = (
            _pen_name in ("group_mcp", "group_scad", "gmcp", "gscad") and _is_glm_loss
        )
        # Block CD solves the least-squares group lasso (it never looks at
        # self._loss), so only use it for squared_error. GLM losses go to the
        # generic solvers below.
        _use_block_cd = _pen_name == "group_lasso" and not _is_glm_loss

        # Feature columns only (intercept column stripped) for the LLA paths.
        X_feat = X_work[:, :p] if self._effective_intercept else X_work

        if _use_fista:
            # Weighted-L1 proximal FISTA; works for GLM and squared_error on
            # any backend and avoids slow sequential CD.
            params, n_iter = fista_solver(
                self._loss, pen, X_work, y_arr,
                max_iter=self.max_iter, tol=self.tol,
                init_coef=init, sample_weight=sample_weight,
            )
        elif _use_lla_squared:
            # squared_error + SCAD/MCP: fused FISTA+LLA on all backends, for
            # identical CPU/GPU results without sequential CD on the GPU.
            # lambda_max is computed on the host.
            _lam_max = _lambda_max(np, _to_numpy(X_feat), _to_numpy(y_arr), False)
            _target_alpha = float(getattr(self._penalty, 'alpha', self.alpha))
            _alpha_path, _max_lla, _mi_path = _default_schedule(_lam_max, _target_alpha)
            coef_out, intercept, n_iter = fista_lla_path(
                self._loss, self._penalty,
                X_feat, y_arr,
                alpha_path=_alpha_path,
                max_lla_per_step=_max_lla,
                lla_tol=getattr(self, '_lla_tol', 1e-6),
                max_iter=_mi_path,
                tol=self.tol,
                fit_intercept=self._effective_intercept,
                sample_weight=sample_weight,
            )
            params = _lla_params(coef_out, intercept)
        elif _use_lla_fista:
            # GLM + SCAD/MCP: LLA outer loop + FISTA inner solve. lambda_max
            # uses backend-native arrays (no host transfer of X).
            xp = get_backend(backend_name).xp
            _lam_max = _lambda_max(xp, X_feat, y_arr, backend_name == "torch")
            _cv_alpha_path = getattr(self, '_cv_alpha_path', None)
            _cv_return_path = _cv_alpha_path is not None
            if _cv_return_path:
                # Use the CV grid (finite, positive, descending), preceded by
                # a lambda_max start point when that lies above the grid.
                _targets = np.asarray(_cv_alpha_path, dtype=float).ravel()
                _targets = _targets[np.isfinite(_targets) & (_targets > 0.0)]
                if _targets.size == 0:
                    _targets = np.asarray([float(getattr(self._penalty, 'alpha', self.alpha))])
                _targets = np.sort(_targets)[::-1]
                _alpha_start = max(_lam_max, float(_targets[0]) * 1.1)
                if _alpha_start > float(_targets[0]) * (1.0 + 1e-10):
                    _alpha_path = np.concatenate([[_alpha_start], _targets])
                else:
                    _alpha_path = _targets
                _n_cont = int(_alpha_path.size)
                _max_lla = max(6, getattr(self, '_max_lla_iters', 50) // max(_n_cont, 1))
                _mi_path = (
                    [max(200, self.max_iter // 2)] * max(_n_cont - 1, 0) + [self.max_iter]
                )
            else:
                _target_alpha = float(getattr(self._penalty, 'alpha', self.alpha))
                _alpha_path, _max_lla, _mi_path = _default_schedule(_lam_max, _target_alpha)

            # Warm start: accept either [coef..., intercept] or coef alone
            # (with the intercept taken from _init_intercept).
            _warm_coef = None
            _warm_intercept = None
            _init = getattr(self, '_init_coef', None)
            if _init is not None:
                _init_np = np.asarray(_to_numpy(_init), dtype=np.float64).ravel()
                if self._effective_intercept and _init_np.size == p + 1:
                    _warm_coef = _init_np[:p]
                    _warm_intercept = float(_init_np[p])
                elif _init_np.size == p:
                    _warm_coef = _init_np
                    if self._effective_intercept:
                        _warm_intercept = float(
                            getattr(self, '_init_intercept', 0.0) or 0.0
                        )

            _lla_result = fista_lla_path(
                self._loss, self._penalty,
                X_feat, y_arr,
                alpha_path=_alpha_path,
                max_lla_per_step=_max_lla,
                lla_tol=getattr(self, '_lla_tol', 1e-6),
                max_iter=_mi_path,
                tol=self.tol,
                fit_intercept=self._effective_intercept,
                sample_weight=sample_weight,
                init_coef=_warm_coef,
                init_intercept=_warm_intercept,
                return_path=_cv_return_path,
            )
            if _cv_return_path:
                coef_out, intercept, n_iter, self._cv_path_results = _lla_result
            else:
                coef_out, intercept, n_iter = _lla_result
            params = _lla_params(coef_out, intercept)
        elif _use_lla_group:
            # GLM + group_mcp/group_scad: LLA outer loop + FISTA inner solve
            # with AdaptiveGroupLassoPenalty as the inner penalty.
            from statgpu.penalties._group_lasso import AdaptiveGroupLassoPenalty

            xp = get_backend(backend_name).xp
            _lam_max = _lambda_max(xp, X_feat, y_arr, backend_name == "torch")
            _target_alpha = float(getattr(self._penalty, 'alpha', self.alpha))
            _alpha_path, _max_lla, _mi_path = _default_schedule(_lam_max, _target_alpha)

            _groups = getattr(self._penalty, '_group_indices', None)
            # Create the inner penalty once and reuse it via set_weights() to
            # avoid repeated group initialisation.
            _adaptive_pen = AdaptiveGroupLassoPenalty(
                groups=_groups, alpha=float(self._penalty.alpha),
            )

            def _group_lla_factory(weights_np):
                # Collapse per-coordinate LLA weights to one weight per group:
                # the L2 norm of that group's weights (0.0 for empty groups).
                group_weights = np.array([
                    float(np.sqrt(np.sum(weights_np[idx] ** 2))) if len(idx) > 0 else 0.0
                    for idx in _groups
                ])
                _adaptive_pen.set_weights(group_weights)
                return _adaptive_pen

            coef_out, intercept, n_iter = fista_lla_path(
                self._loss, self._penalty,
                X_feat, y_arr,
                alpha_path=_alpha_path,
                max_lla_per_step=_max_lla,
                lla_tol=getattr(self, '_lla_tol', 1e-6),
                max_iter=_mi_path,
                tol=self.tol,
                fit_intercept=self._effective_intercept,
                sample_weight=sample_weight,
                lla_penalty_factory=_group_lla_factory,
            )
            params = _lla_params(coef_out, intercept)
        elif _use_block_cd:
            # Block CD for least-squares group lasso; GPU-native on GPU backends.
            if backend_name != "numpy":
                coef_gpu, intercept, n_iter = self._block_cd_group_lasso_gpu(
                    pen, X_work, y_arr, init, backend_name,
                )
                if self._effective_intercept:
                    _int_arr = _xp_asarray([intercept], coef_gpu.dtype, coef_gpu)
                    params = _xp.concatenate([coef_gpu, _int_arr])
                else:
                    params = coef_gpu
            else:
                coef_np, intercept, n_iter = self._block_cd_group_lasso(
                    pen, X_work, y_arr, init,
                )
                if self._effective_intercept:
                    params = np.concatenate([coef_np, [intercept]])
                else:
                    params = coef_np
        else:
            _common = dict(
                max_iter=self.max_iter, tol=self.tol,
                init_coef=init, sample_weight=sample_weight,
            )
            if solver_name == "auto":
                # For smooth penalties (l2, elasticnet with l1_ratio < 0.5),
                # BB step sizes converge more reliably than Nesterov FISTA.
                # l1_ratio falls back to the unwrapped penalty because the
                # intercept wrapper may not expose it.
                _is_smooth = (_pen_name == "l2") or (
                    _pen_name == "elasticnet"
                    and float(getattr(pen, 'l1_ratio',
                                      getattr(self._penalty, 'l1_ratio', 1.0))) < 0.5
                )
                _solve = fista_bb_solver if _is_smooth else fista_solver
                params, n_iter = _solve(self._loss, pen, X_work, y_arr, **_common)
            elif solver_name == "admm":
                params, n_iter = admm_solver(
                    self._loss, pen, X_work, y_arr,
                    rho=1.0, adaptive_rho=True, **_common,
                )
            else:
                _dispatch = {
                    "fista": fista_solver,
                    "fista_bb": fista_bb_solver,
                    "newton": newton_solver,
                    "lbfgs": lbfgs_solver,
                }
                if solver_name not in _dispatch:
                    raise ValueError(f"Unsupported solver: {solver_name}")
                params, n_iter = _dispatch[solver_name](
                    self._loss, pen, X_work, y_arr, **_common,
                )

        params_np = _to_numpy(params)
        self.n_iter_ = n_iter
        if self._effective_intercept:
            # Working layout is [coef..., intercept]; _params is [intercept, coef...].
            self.coef_ = params_np[:p]
            self.intercept_ = float(params_np[p])
            self._params = np.concatenate([[self.intercept_], self.coef_])
        else:
            self.coef_ = params_np.copy()
            self.intercept_ = 0.0
            self._params = self.coef_.copy()
        self._df_resid = self._nobs - (p + (1 if self._effective_intercept else 0))
    finally:
        # Release cached GPU blocks even when a solver raises.
        if backend_name == "cupy":
            self._cleanup_cuda_memory()
        elif backend_name == "torch":
            self._cleanup_torch_memory()
1756
+
1757
def _fit_irls_backend(self, X, y, sample_weight=None, backend_name="numpy"):
    """Fit smooth L2 GLM via IRLS on the selected backend.

    The IRLS working design puts the intercept column FIRST (``[1, X]``), so
    warm starts are packed as ``[intercept, coef...]``. The FISTA-based fit
    appends the intercept last instead.

    Warm start priority:

    1. ``self._init_coef`` (plus ``self._init_intercept`` when fitting an
       intercept), e.g. from a CV path.
    2. Otherwise, for logistic and log-link losses with an intercept, the
       intercept starts at link(mean(y)) and the slopes at zero.

    Side effects: sets ``coef_``, ``intercept_``, ``n_iter_``, ``_params``
    and ``_df_resid``. GPU memory pools are released afterwards (subject to
    ``gpu_memory_cleanup``), whether or not the fit succeeds.

    Raises
    ------
    ValueError
        If the configured penalty is not L2.
    """
    from statgpu.glm_core._irls import IRLSSolver

    if str(getattr(self._penalty, "name", self.penalty)).lower() != "l2":
        raise ValueError("solver='irls' only supports L2 penalties.")

    from statgpu.backends._utils import _get_xp, xp_asarray

    _log_link_losses = ("gamma", "poisson", "inverse_gaussian",
                        "negative_binomial", "tweedie")
    _loss_name = getattr(self._loss, 'name', '')
    _xp = _get_xp(backend_name)

    try:
        X_arr = xp_asarray(
            X, dtype=_xp.float64, xp=_xp,
            ref_arr=X if not isinstance(X, np.ndarray) else np.zeros(1),
        )
        y_arr = xp_asarray(y, dtype=_xp.float64, xp=_xp, ref_arr=X_arr)
        n_samples = X_arr.shape[0]
        p = X_arr.shape[1]
        if self._effective_intercept:
            X_work = self._column_stack(
                [self._ones(n_samples, backend_name, X_arr), X_arr],
                backend_name,
            )
        else:
            X_work = X_arr

        def _to_backend(vec_np):
            # Place a host float64 vector on the backend/device of X_work.
            if backend_name == "cupy":
                import cupy as cp

                return cp.asarray(vec_np, dtype=cp.float64)
            if backend_name == "torch":
                import torch

                return torch.as_tensor(vec_np, dtype=torch.float64, device=X_work.device)
            return vec_np

        init_coef = None
        init_features = getattr(self, '_init_coef', None)
        if init_features is not None:
            # Respect CV warm starts first. _to_numpy lets GPU-resident
            # coefficients through (np.asarray alone rejects them).
            init_features_np = np.asarray(_to_numpy(init_features), dtype=np.float64).ravel()
            if self._effective_intercept:
                init_intercept = float(getattr(self, '_init_intercept', 0.0) or 0.0)
                init_coef_np = np.concatenate([[init_intercept], init_features_np])
            else:
                init_coef_np = init_features_np
            init_coef = _to_backend(init_coef_np)
        elif self._effective_intercept and (
            _loss_name in _log_link_losses or _loss_name == "logistic"
        ):
            # Default eta = 0 can be far from the intercept-only optimum for
            # these links, so start the intercept at link(mean(y)).
            _y_mean = float(np.mean(_to_numpy(y_arr)))
            if _loss_name == "logistic":
                _y_mean = float(np.clip(_y_mean, 1e-3, 1.0 - 1e-3))
                _int_init = np.log(_y_mean / (1.0 - _y_mean))
            else:
                _int_init = np.log(max(_y_mean, 1e-3))
            init_coef_np = np.zeros(X_work.shape[1])
            init_coef_np[0] = _int_init
            init_coef = _to_backend(init_coef_np)

        solver = IRLSSolver(
            self._family_for_loss(), max_iter=self.max_iter, tol=self.tol
        )
        params, n_iter = solver.fit(
            X_work, y_arr,
            sample_weight=sample_weight,
            ridge_alpha=float(n_samples * self.alpha),
            ridge_penalize_intercept=not self._effective_intercept,
            backend=backend_name,
            init_coef=init_coef,
        )

        params_np = _to_numpy(params)
        self.n_iter_ = n_iter
        if self._effective_intercept:
            self.intercept_ = float(params_np[0])
            self.coef_ = params_np[1:]
            self._params = np.concatenate([[self.intercept_], self.coef_])
        else:
            self.intercept_ = 0.0
            self.coef_ = params_np.copy()
            self._params = self.coef_.copy()
        self._df_resid = self._nobs - (p + (1 if self._effective_intercept else 0))
    finally:
        # Release cached GPU blocks even when IRLS raises.
        if backend_name == "cupy":
            self._cleanup_cuda_memory()
        elif backend_name == "torch":
            self._cleanup_torch_memory()
1856
+
1857
def _cleanup_cuda_memory(self):
    """Free CuPy memory pool.

    Only acts when ``self.gpu_memory_cleanup`` is truthy. It then releases
    the cached blocks of CuPy's default device pool and default pinned-memory
    pool. The release is best-effort: any failure, including CuPy not being
    importable, is ignored.
    """
    if self.gpu_memory_cleanup:
        try:
            import cupy as cp

            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
        except Exception:
            # Deliberate best-effort: cleanup must never fail a fit.
            pass
1867
+
1868
def _cleanup_torch_memory(self):
    """Free Torch memory pool.

    Only acts when ``self.gpu_memory_cleanup`` is truthy. It then calls
    ``torch.cuda.empty_cache()`` if CUDA is available. The call is
    best-effort: any failure, including torch not being importable, is
    ignored.
    """
    if self.gpu_memory_cleanup:
        try:
            import torch

            cuda = torch.cuda
            if cuda.is_available():
                cuda.empty_cache()
        except Exception:
            # Deliberate best-effort: cleanup must never fail a fit.
            pass