statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1174 @@
1
+ """Legacy solver/CD methods from _penalized.py.
2
+
3
+ These methods were replaced by newer implementations (FISTA, _fit_loss_backend)
4
+ but are retained for reference and backward compatibility.
5
+
6
+ DO NOT import or use in production code.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import numpy as np
12
+
13
+ # Methods extracted from PenalizedGeneralizedLinearModel:
14
def _irls_cd_gpu(self, pen, X_work, y_arr, init, backend_name, _lla_continuation=False):
    """GPU-native IRLS with coordinate descent for GLM + non-smooth penalties.

    Same algorithm as ``_irls_cd`` but keeps all arrays on the backend
    selected by ``backend_name`` to avoid CPU-GPU transfer overhead.
    Supports cupy and torch backends.

    Parameters
    ----------
    pen : penalty object
        Penalty for this fit. Its ``name`` selects the thresholding rule:
        "adaptive_l1"/"adaptive_lasso", "scad", "mcp", otherwise plain L1.
        ``self._penalty`` (when present) supplies ``_weights``, ``alpha``
        and the SCAD ``a`` / MCP ``gamma`` parameters.
    X_work : array (n, pp)
        Design matrix on the backend device. When ``self._effective_intercept``
        is true, the last coefficient is treated as an unpenalized intercept.
    y_arr : array
        Response on the backend device.
    init : array or None
        Warm start of length ``pp``; zeros when None.
    backend_name : str
        Backend identifier passed to ``_get_xp``.
    _lla_continuation : bool
        When True, skip the internal SCAD/MCP lambda continuation path
        (the caller drives the path).

    Returns
    -------
    beta : array (pp,)
        Fitted coefficients (intercept last when present).
    n_iter : int
        Outer IRLS iterations run on the final path stage.

    Notes
    -----
    NOTE(review): ``_to_numpy``, ``_exp`` and ``_clip`` are neither imported
    in this function nor at the top of this module; they must be provided at
    module level elsewhere in this file — TODO confirm, otherwise this raises
    NameError.
    """
    from statgpu.backends._array_ops import _xp_copy, _xp_zeros, _xp_asarray
    from statgpu.backends._utils import _get_xp
    xp = _get_xp(backend_name)

    n, pp = X_work.shape
    p = pp - 1 if self._effective_intercept else pp

    # Access weights from the original penalty (not a wrapper around it)
    _inner = getattr(self, '_penalty', pen)
    _w_np = np.asarray(getattr(_inner, '_weights', np.ones(p)), dtype=float)
    _w = _xp_asarray(_w_np, X_work.dtype, X_work)
    alpha = float(getattr(_inner, 'alpha', self.alpha))
    pen_name = getattr(pen, 'name', '') or getattr(_inner, 'name', '')

    # SCAD/MCP parameters (guard against division-by-zero in the thresholds)
    a_scad = float(getattr(_inner, 'a', 3.7)) if pen_name == "scad" else 0.0
    if pen_name == "scad":
        a_scad = max(a_scad, 1.0 + 1e-6)
        if abs(a_scad - 2.0) < 1e-6:
            a_scad = 2.0 + 1e-6
    gamma_mcp = float(getattr(_inner, 'gamma', 3.0)) if pen_name == "mcp" else 0.0
    if pen_name == "mcp":
        gamma_mcp = max(gamma_mcp, 1.0 + 1e-6)

    # Penalty value helper (uses numpy for portability; coef_slice is numpy).
    # NOTE(review): returns 0.0 for L1 / adaptive-L1 penalties, so for those
    # the step-halving check below compares the unpenalized loss only.
    def _nonconvex_penalty_value(coef_slice, _pen_name, _alpha, _a_scad, _gamma_mcp):
        _abs_b = np.abs(coef_slice)
        if _pen_name == "scad":
            return float(np.sum(np.where(
                _abs_b <= _alpha, _alpha * _abs_b,
                np.where(_abs_b <= _a_scad * _alpha,
                         (_a_scad * _alpha * _abs_b - 0.5 * (coef_slice**2 + _alpha**2)) / (_a_scad - 1.0),
                         0.5 * (_a_scad + 1.0) * _alpha**2))))
        if _pen_name == "mcp":
            return float(np.sum(np.where(
                _abs_b <= _gamma_mcp * _alpha,
                _alpha * _abs_b - 0.5 * coef_slice**2 / _gamma_mcp,
                0.5 * _gamma_mcp * _alpha**2)))
        return 0.0

    if init is None:
        beta = _xp_zeros(pp, X_work.dtype, X_work)
    elif isinstance(init, np.ndarray):
        beta = _xp_asarray(init, X_work.dtype, X_work)
    else:
        beta = _xp_copy(init)

    loss_name = self._loss.name
    _is_glm = (loss_name != "squared_error")

    # Continuation path for SCAD/MCP: warm-start from lambda_max down to alpha.
    _cont_path = [alpha]
    if pen_name in ("scad", "mcp") and not _lla_continuation:
        _y_np = _to_numpy(y_arr)
        if loss_name == "logistic":
            _p0 = np.clip(np.mean(_y_np), 1e-3, 1 - 1e-3)
            _resid = _y_np - _p0
        elif loss_name == "poisson":
            _mu0 = max(float(np.mean(_y_np)), 1e-3)
            _resid = _y_np - _mu0
        elif loss_name == "gamma":
            _mu0 = max(float(np.mean(_y_np)), 1e-3)
            _resid = (_y_np - _mu0) / _mu0
        else:
            _resid = _y_np - np.mean(_y_np)
        _X_np = _to_numpy(X_work)
        _xty = np.abs(_X_np[:, :p].T @ _resid)
        _xnorm_sq = np.sum(_X_np[:, :p] ** 2, axis=0)
        _xnorm_sq = np.maximum(_xnorm_sq, 1e-20)
        _lam_max = float(np.max(_xty / _xnorm_sq))
        if _lam_max > alpha * 1.1:
            _n_cont = 100
            _cont_path = np.geomspace(_lam_max, alpha, _n_cont)

    _n_cd_sweeps_base = 1 if _is_glm else min(self.max_iter, 200)
    _n_outer_base = self.max_iter if _is_glm else 1

    # squared_error: unit weights and z = y are fixed, so X'DX's diagonal is too
    if not _is_glm:
        d = _xp_zeros((n,), X_work.dtype, X_work) + 1.0  # ones on correct device
        z = y_arr
        XDX_diag = xp.sum(d[:, None] * X_work ** 2, axis=0)

    for _cont_idx, _cont_alpha in enumerate(_cont_path):
        _is_last = (_cont_idx == len(_cont_path) - 1)
        if len(_cont_path) > 1:
            alpha = float(_cont_alpha)
            _n_cd_sweeps = _n_cd_sweeps_base if _is_last else 20
            if _is_glm:
                _n_outer = _n_outer_base if _is_last else min(20, _n_outer_base)
            else:
                _n_outer = _n_outer_base
        else:
            _n_cd_sweeps = _n_cd_sweeps_base
            _n_outer = _n_outer_base
        # Intermediate warm-start stages only need a loose solution.
        _stage_tol = self.tol if _is_last else self.tol * 10

        it = -1
        for it in range(_n_outer):
            beta_old = beta.clone() if backend_name == "torch" else beta.copy()

            # Compute objective before CD for step-halving (GLM only)
            _obj_before = None
            if _is_glm:
                try:
                    _obj_before = float(xp.sum(self._loss.per_sample_value(X_work, y_arr, beta_old)))
                    _obj_before += _nonconvex_penalty_value(
                        _to_numpy(beta_old[:p]) if backend_name != "numpy" else beta_old[:p],
                        pen_name, alpha, a_scad, gamma_mcp)
                except Exception:
                    # Best-effort: without a baseline, step-halving is skipped.
                    _obj_before = None

            if _is_glm:
                # Working weights d and working response z for this IRLS step
                eta = X_work @ beta
                if loss_name == "logistic":
                    mu = 1.0 / (1.0 + _exp(-_clip(eta, -500, 500)))
                    mu = _clip(mu, 1e-15, 1.0 - 1e-15)
                    d = mu * (1.0 - mu)
                    z = eta + (y_arr - mu) / d
                elif loss_name == "poisson":
                    mu = _clip(_exp(_clip(eta, -500, 500)), 1e-15, None)
                    d = mu
                    z = eta + (y_arr - mu) / d
                elif loss_name == "gamma":
                    mu = _clip(_exp(_clip(eta, -500, 500)), 1e-15, None)
                    d = _xp_zeros((n,), X_work.dtype, X_work) + 1.0
                    z = eta + (y_arr - mu) / mu
                elif loss_name == "inverse_gaussian":
                    mu = _clip(_exp(_clip(eta, -500, 500)), 1e-15, None)
                    d = 1.0 / mu
                    z = eta + (y_arr - mu) / mu
                elif loss_name == "negative_binomial":
                    mu = _clip(_exp(_clip(eta, -500, 500)), 1e-15, None)
                    theta_nb = float(getattr(self._loss, 'alpha', 1.0))
                    d = mu / (1.0 + mu / theta_nb)
                    z = eta + (y_arr - mu) / d
                elif loss_name == "tweedie":
                    mu = _clip(_exp(_clip(eta, -500, 500)), 1e-15, None)
                    tweedie_p = float(getattr(self._loss, 'power', 1.5))
                    d = mu ** tweedie_p
                    d = _clip(d, 1e-15, None)
                    z = eta + (y_arr - mu) / (d * mu)
                else:
                    grad = self._loss.gradient(X_work, y_arr, beta)
                    d = _xp_zeros((n,), X_work.dtype, X_work) + 1.0
                    z = eta - grad * n
                XDX_diag = xp.sum(d[:, None] * X_work ** 2, axis=0)

            # Effective sample size for correct normalization with sample weights
            n_eff = float(xp.sum(d))

            r = z - X_work @ beta

            # Precompute active mask and vectorized penalty weights
            _active = XDX_diag >= 1e-20
            _v_all = XDX_diag / n_eff
            _v_safe = xp.where(_active, _v_all, 1.0)  # avoid division by zero
            if pen_name in ("adaptive_l1", "adaptive_lasso"):
                _l1_all = alpha * _w  # shape (p,)

            for _cd in range(_n_cd_sweeps):
                # --- Vectorized block coordinate descent ---
                # 1. Batch gradient: rho_all = X' (d * r) + XDX_diag * beta
                rho_all = X_work.T @ (d * r) + XDX_diag * beta
                w_all = rho_all / (n_eff * _v_safe)  # un-penalized solution

                # 2. Save old beta for residual update
                old_beta = beta

                # 3. Vectorized thresholding (penalty-specific)
                new_beta = xp.zeros_like(beta)
                w_feat = w_all[:p] if self._effective_intercept else w_all
                aw = xp.abs(w_feat)

                if pen_name in ("adaptive_l1", "adaptive_lasso"):
                    new_beta_feat = xp.sign(w_feat) * xp.maximum(aw - _l1_all, 0.0)
                elif pen_name == "scad":
                    l1 = alpha
                    # SCAD thresholding operator, three regions:
                    #   |w| <= 2*l1         -> soft threshold S(w, l1)
                    #   2*l1 < |w| <= a*l1  -> ((a-1)|w| - a*l1) / (a-2), signed
                    #   |w| > a*l1          -> w (no shrinkage)
                    new_beta_feat = xp.where(
                        aw > a_scad * l1, w_feat,
                        xp.where(
                            aw > 2.0 * l1,
                            xp.sign(w_feat) * ((a_scad - 1.0) * aw - a_scad * l1) / (a_scad - 2.0),
                            xp.sign(w_feat) * xp.maximum(aw - l1, 0.0),
                        ),
                    )
                elif pen_name == "mcp":
                    l1 = alpha
                    # MCP firm thresholding: S(w, l1) / (1 - 1/gamma) below gamma*l1
                    new_beta_feat = xp.where(
                        aw > gamma_mcp * l1, w_feat,
                        xp.where(
                            aw > l1,
                            xp.sign(w_feat) * (aw - l1) / (1.0 - 1.0 / gamma_mcp),
                            0.0,
                        ),
                    )
                else:
                    # lasso / elasticnet (pure L1)
                    new_beta_feat = xp.sign(w_feat) * xp.maximum(aw - alpha, 0.0)

                # Zero out degenerate columns
                if self._effective_intercept:
                    new_beta[:p] = new_beta_feat * _active[:p]
                    new_beta[p:] = w_all[p:]  # intercept: no penalty
                else:
                    new_beta = new_beta_feat * _active

                # 4. Residual update (single matvec instead of p dot products)
                delta = new_beta - old_beta
                r = r - X_work @ delta
                beta = new_beta

                # 5. Convergence check (single reduction + one sync)
                _max_cd_change = float(xp.max(xp.abs(delta)))

                if not _is_glm and _max_cd_change < self.tol:
                    break

            # Step-halving for GLM: ensure penalized objective decreases.
            # Mirrors the CPU path (_irls_cd) to prevent IRLS overshooting.
            if _is_glm:
                _obj_after = float(xp.sum(self._loss.per_sample_value(X_work, y_arr, beta)))
                _obj_after += _nonconvex_penalty_value(
                    _to_numpy(beta[:p]) if backend_name != "numpy" else beta[:p],
                    pen_name, alpha, a_scad, gamma_mcp)
                if _obj_before is not None and _obj_after > _obj_before + 1e-10:
                    beta_new_gpu = beta.clone() if backend_name == "torch" else beta.copy()
                    for _sh in range(1, 11):
                        _frac = 0.5 ** _sh
                        beta_sh = beta_old + _frac * (beta_new_gpu - beta_old)
                        _obj_sh = float(xp.sum(self._loss.per_sample_value(X_work, y_arr, beta_sh)))
                        _obj_sh += _nonconvex_penalty_value(
                            _to_numpy(beta_sh[:p]) if backend_name != "numpy" else beta_sh[:p],
                            pen_name, alpha, a_scad, gamma_mcp)
                        if _obj_sh <= _obj_before + 1e-10:
                            beta = beta_sh
                            break
                    else:
                        # All step-halving attempts failed — revert to previous iterate
                        beta = beta_old

            # IRLS-level convergence check on the full coefficient vector
            # (intercept included, so an intercept-only move is not mistaken
            # for convergence). Applies to GLM fits too; non-GLM fits run a
            # single outer iteration anyway.
            _delta = float(xp.max(xp.abs(beta - beta_old)))
            if _delta < _stage_tol:
                break

    n_iter = it + 1 if _n_outer > 0 else 0
    return beta, n_iter
278
+
279
+
280
def _block_cd_group_lasso_gpu_batched(self, pen, X_work, y_arr, init, backend_name):
    """Batched GPU block coordinate descent for the group_lasso penalty.

    Groups are bucketed by size. Within one bucket every group is updated
    from the same coefficient snapshot using a single batched linear solve.
    Buckets are processed one after another. This keeps kernel-launch
    overhead low.

    Parameters
    ----------
    pen : penalty object
        Fallback source of ``_group_indices`` / ``_sqrt_pg`` when
        ``self._penalty`` is not set.
    X_work : array (n, pp)
        Design matrix on the backend device. When ``self._effective_intercept``
        is true, the last column is the intercept column; the intercept
        update below assumes it is all ones.
    y_arr : array
        Response with ``n`` values (1-D or column vector).
    init : array or None
        Warm start of length ``pp``; zeros when None.
    backend_name : str
        Backend identifier passed to ``_get_xp``.

    Returns
    -------
    beta : array (p,)
    intercept : float
    n_iter : int

    Raises
    ------
    ValueError
        If the penalty has no groups configured.

    Notes
    -----
    NOTE(review): thresholding the unpenalized block solution
    ``XtX_g^{-1} rho_g`` at ``alpha * sqrt(p_g)`` is the exact group-lasso
    update only when the within-group block ``XtX_g`` is the identity;
    otherwise it is an approximation.
    """
    from statgpu.backends._array_ops import _xp_copy, _xp_zeros, _xp_asarray
    from statgpu.backends._utils import _get_xp
    xp = _get_xp(backend_name)

    n, pp = X_work.shape
    p = pp - 1 if self._effective_intercept else pp
    alpha = self.alpha

    _inner = getattr(self, '_penalty', pen)
    _g_indices = getattr(_inner, '_group_indices', None)
    _sqrt_pg_np = getattr(_inner, '_sqrt_pg', None)
    if _g_indices is None or _sqrt_pg_np is None:
        raise ValueError(
            "group_lasso penalty must have groups set. "
            "Pass groups=... in penalty_kwargs."
        )
    _sqrt_pg = [float(s) for s in _sqrt_pg_np]

    # Flatten y once. Otherwise a column-vector y would broadcast against the
    # (n,) fitted values in the intercept update and build an (n, n) temporary.
    y_flat = y_arr.flatten()

    # Pre-compute XtX and Xty once (scaled by 1/n)
    XtX = X_work.T @ X_work / n
    Xty = (X_work.T @ y_flat) / n

    # Diagonal XtX block of every group
    _XtX_blocks = [XtX[g_idx][:, g_idx] for g_idx in _g_indices]

    # Bucket groups by size: size -> [(group number, indices), ...]
    _size_groups = {}
    for g, g_idx in enumerate(_g_indices):
        _size_groups.setdefault(len(g_idx), []).append((g, g_idx))

    if init is None:
        coef = _xp_zeros(pp, X_work.dtype, X_work)
    elif isinstance(init, np.ndarray):
        coef = _xp_asarray(init, X_work.dtype, X_work)
    else:
        coef = _xp_copy(init)

    iteration = -1  # ensure defined when max_iter=0
    for iteration in range(self.max_iter):
        coef_old = _xp_copy(coef)

        for sz, size_groups in _size_groups.items():
            n_batch = len(size_groups)  # buckets are never empty
            batch_g_indices = [g for g, _ in size_groups]
            all_indices = [j for _, g_idx in size_groups for j in g_idx]

            # rho_g = Xty[g_idx] - XtX[g_idx, :] @ coef + XtX_block[g] @ coef[g_idx]
            idx_arr = _xp_asarray(all_indices, xp.int32 if backend_name == "cupy" else None, X_work)
            XtX_coef = XtX[idx_arr, :] @ coef  # shape: (n_batch * sz,)
            Xty_all = Xty[idx_arr]
            block_contrib = _xp_zeros(Xty_all.shape, Xty_all.dtype, Xty_all)
            for i, (g, g_idx) in enumerate(size_groups):
                block_contrib[i * sz:(i + 1) * sz] = _XtX_blocks[g] @ coef[g_idx]
            rho_all = Xty_all - XtX_coef + block_contrib

            # Solve all group systems of this size in one batched call
            rho_mat = rho_all.reshape(n_batch, sz, 1)
            XtX_batch = xp.stack([_XtX_blocks[g] for g in batch_g_indices])
            try:
                w_all = xp.linalg.solve(XtX_batch, rho_mat).reshape(n_batch, sz)
            except Exception:
                # Best-effort fallback: a failed (e.g. singular) batched solve
                # zeroes every group in this bucket.
                w_all = _xp_zeros((n_batch, sz), X_work.dtype, X_work)

            # Group soft-thresholding on the block norms
            norms = xp.linalg.norm(w_all, axis=1)  # (n_batch,)
            thresh = _xp_asarray(
                [alpha * _sqrt_pg[g] for g in batch_g_indices],
                X_work.dtype, X_work,
            )
            scale = xp.where(norms > thresh, 1.0 - thresh / (norms + 1e-12), 0.0)

            # Write back coefficients
            for i, (g, g_idx) in enumerate(size_groups):
                coef[g_idx] = w_all[i] * scale[i]

        if self._effective_intercept:
            coef[pp - 1] = float(xp.mean(y_flat - X_work[:, :p] @ coef[:p]))

        _max_change = float(xp.max(xp.abs(coef - coef_old)))
        if _max_change < self.tol:
            break

    n_iter = iteration + 1

    if self._effective_intercept:
        beta = coef[:p]
        intercept = float(coef[p])
    else:
        beta = coef
        intercept = 0.0

    return beta, intercept, n_iter
402
+
403
+
404
def _cd_elasticnet(self, pen, X_work, y_arr, init):
    """Coordinate descent for the elastic-net penalty (squared_error loss).

    Matches R glmnet's CD update for elastic net::

        beta_j = S(rho_j, alpha*l1_ratio*n) / (X_j'X_j + alpha*(1-l1_ratio)*n)

    where ``S`` is soft-thresholding and ``rho_j`` is the correlation of
    column ``j`` with the partial residual that excludes column ``j``.

    Parameters
    ----------
    self : object
        Reads ``alpha``, ``max_iter``, ``tol``, ``_effective_intercept`` and,
        as a fallback, ``l1_ratio``.
    pen : object
        Penalty. Its ``l1_ratio`` takes precedence over ``self.l1_ratio``
        (0.5 when neither is set).
    X_work : ndarray (n, pp)
        Design matrix. When ``self._effective_intercept`` is true, the last
        column is the intercept column; the intercept update assumes it is
        all ones.
    y_arr : array-like
        Response with ``n`` values (1-D or a column vector).
    init : array-like or None
        Warm start of length ``pp``; zeros when None.

    Returns
    -------
    beta : ndarray (p,)
    intercept : float
    n_iter : int
    """
    import numpy as np

    n, pp = X_work.shape
    p = pp - 1 if self._effective_intercept else pp
    alpha = self.alpha
    l1_ratio = getattr(pen, 'l1_ratio', getattr(self, 'l1_ratio', 0.5))

    # Flatten y once. Otherwise a column-vector y would broadcast against the
    # (n,) fitted values in the intercept update and build an (n, n) temporary.
    y_flat = np.asarray(y_arr).ravel()

    XtX = X_work.T @ X_work
    Xty = X_work.T @ y_flat
    X_sq_norms = np.diag(XtX)

    if init is not None:
        coef = np.array(init, dtype=np.float64)
    else:
        coef = np.zeros(pp, dtype=np.float64)

    thresh = alpha * l1_ratio * n        # L1 part: soft-threshold level
    ridge = alpha * (1 - l1_ratio) * n   # L2 part: added to the denominator

    iteration = -1  # ensure defined when max_iter=0
    for iteration in range(self.max_iter):
        coef_old = coef.copy()

        for j in range(p):
            # Add back column j's own contribution to get the partial residual.
            rho_j = Xty[j] - np.dot(XtX[j, :], coef) + XtX[j, j] * coef[j]
            if X_sq_norms[j] > 1e-10:
                st = np.sign(rho_j) * np.maximum(np.abs(rho_j) - thresh, 0)
                coef[j] = st / (X_sq_norms[j] + ridge)
            else:
                coef[j] = 0.0  # degenerate (all-zero) column

        if self._effective_intercept:
            coef[pp - 1] = np.mean(y_flat - X_work[:, :p] @ coef[:p])

        if np.max(np.abs(coef - coef_old)) < self.tol:
            break

    n_iter = iteration + 1

    if self._effective_intercept:
        beta = coef[:p]
        intercept = float(coef[p])
    else:
        beta = coef
        intercept = 0.0

    return beta, intercept, n_iter
456
+
457
+
458
def _cd_l1(self, pen, X_work, y_arr, init):
    """Coordinate descent for the L1 (lasso) penalty (squared_error loss).

    Matches R glmnet's CD update::

        beta_j = S(rho_j, alpha*n) / X_j'X_j

    where ``S`` is soft-thresholding and ``rho_j`` is the correlation of
    column ``j`` with the partial residual that excludes column ``j``.

    Parameters
    ----------
    self : object
        Reads ``alpha``, ``max_iter``, ``tol`` and ``_effective_intercept``.
    pen : object
        Penalty; not read here (kept for a uniform solver signature).
    X_work : ndarray (n, pp)
        Design matrix. When ``self._effective_intercept`` is true, the last
        column is the intercept column; the intercept update assumes it is
        all ones.
    y_arr : array-like
        Response with ``n`` values (1-D or a column vector).
    init : array-like or None
        Warm start of length ``pp``; zeros when None.

    Returns
    -------
    beta : ndarray (p,)
    intercept : float
    n_iter : int
    """
    import numpy as np

    n, pp = X_work.shape
    p = pp - 1 if self._effective_intercept else pp
    alpha = self.alpha

    # Flatten y once. Otherwise a column-vector y would broadcast against the
    # (n,) fitted values in the intercept update and build an (n, n) temporary.
    y_flat = np.asarray(y_arr).ravel()

    XtX = X_work.T @ X_work
    Xty = X_work.T @ y_flat
    X_sq_norms = np.diag(XtX)

    if init is not None:
        coef = np.array(init, dtype=np.float64)
    else:
        coef = np.zeros(pp, dtype=np.float64)

    thresh = alpha * n

    iteration = -1  # ensure defined when max_iter=0
    for iteration in range(self.max_iter):
        coef_old = coef.copy()

        for j in range(p):
            # Add back column j's own contribution to get the partial residual.
            rho_j = Xty[j] - np.dot(XtX[j, :], coef) + XtX[j, j] * coef[j]
            if X_sq_norms[j] > 1e-10:
                coef[j] = np.sign(rho_j) * np.maximum(np.abs(rho_j) - thresh, 0) / X_sq_norms[j]
            else:
                coef[j] = 0.0  # degenerate (all-zero) column

        if self._effective_intercept:
            coef[pp - 1] = np.mean(y_flat - X_work[:, :p] @ coef[:p])

        if np.max(np.abs(coef - coef_old)) < self.tol:
            break

    n_iter = iteration + 1

    if self._effective_intercept:
        beta = coef[:p]
        intercept = float(coef[p])
    else:
        beta = coef
        intercept = 0.0

    return beta, intercept, n_iter
508
+
509
+
510
def _fit_cpu_loss(self, X, y, sample_weight=None, solver="fista"):
    """Fit using the loss-aware FISTA solver (arbitrary loss) on CPU.

    The intercept is handled according to the loss:

    * Non-Gaussian GLM losses (logistic, poisson, gamma, inverse_gaussian,
      negative_binomial, tweedie) with an intercept: X is augmented with a
      trailing column of ones. A ``SelectivePenalty`` leaves that column
      unpenalized, so the intercept is fitted jointly with the coefficients.
      Centering y is not valid for these non-identity-link losses.
    * Any other loss (squared error) with an intercept: X and y are
      centered, using ``sample_weight``-weighted means when weights are
      given. The model is fitted without intercept, which is then recovered
      as ``mean(y) - mean(X) @ coef``.
    * No intercept: X is fitted directly.

    ``solver`` is accepted for signature compatibility and not used here.

    Sets ``coef_``, ``intercept_``, ``n_iter_`` and ``_df_resid`` on ``self``.
    """
    from statgpu.solvers import fista_solver

    X_arr = np.asarray(X)
    y_arr = np.asarray(y)

    _glm_losses = (
        "logistic", "poisson", "gamma",
        "inverse_gaussian", "negative_binomial", "tweedie",
    )

    if self.loss in _glm_losses and self._effective_intercept:
        # Augment X with intercept column
        X_aug = np.column_stack([X_arr, np.ones(X_arr.shape[0])])
        p = X_arr.shape[1]

        from statgpu.linear_model.penalized._base import SelectivePenalty
        singleton = SelectivePenalty()
        singleton.configure(self._penalty, p, "numpy")

        full_coef, n_iter = fista_solver(
            self._loss, singleton, X_aug, y_arr,
            max_iter=self.max_iter, tol=self.tol,
            init_coef=None, sample_weight=sample_weight,
        )

        self.coef_ = full_coef[:p]
        self.intercept_ = float(full_coef[p])
        self.n_iter_ = n_iter
    elif self._effective_intercept:
        # Squared error: center X and y, fit once. With sample weights the
        # optimal intercept is the weighted mean residual, so center with
        # weighted means.
        if sample_weight is not None:
            sw = np.asarray(sample_weight, dtype=float).ravel()
            x_mean = np.average(X_arr, axis=0, weights=sw)
            y_mean = float(np.average(np.ravel(y_arr), weights=sw))
        else:
            x_mean = X_arr.mean(axis=0)
            y_mean = float(y_arr.mean())

        coef, n_iter = fista_solver(
            self._loss, self._penalty, X_arr - x_mean, y_arr - y_mean,
            max_iter=self.max_iter, tol=self.tol,
            init_coef=None, sample_weight=sample_weight,
        )

        self.coef_ = coef
        self.n_iter_ = n_iter
        self.intercept_ = float(y_mean - x_mean @ coef)
    else:
        coef, n_iter = fista_solver(
            self._loss, self._penalty, X_arr, y_arr,
            max_iter=self.max_iter, tol=self.tol,
            init_coef=None, sample_weight=sample_weight,
        )

        self.coef_ = coef
        self.n_iter_ = n_iter
        self.intercept_ = 0.0

    self._df_resid = self._nobs - (X.shape[1] + (1 if self._effective_intercept else 0))
567
+
568
+
569
def _fit_cpu_irls(self, X, y, sample_weight=None):
    """Fit a smooth-penalty, smooth-loss model (e.g. Logistic/Poisson + L2) by IRLS.

    Each IRLS step builds the working response ``z`` and weights ``W`` and
    solves ``(X'WX + n*alpha*I) params = X'Wz``. The ridge term is applied
    to the unnormalized ``X'WX``, which corresponds to the objective
    ``loss/n + 0.5*alpha*||w||^2``. When an intercept is fitted, a leading
    column of ones is prepended and excluded from the ridge penalty.

    Sets ``coef_``, ``intercept_``, ``n_iter_``, ``_params`` and
    ``_df_resid`` on ``self``.

    NOTE(review): ``NegativeBinomial()`` and ``Tweedie()`` are built with
    their default arguments. Any dispersion/power configured on
    ``self._loss`` is not forwarded here — confirm this is intended.
    """
    from statgpu.glm_core._irls import IRLSSolver
    from statgpu.glm_core._family import (
        Binomial, Gamma, Gaussian, InverseGaussian,
        NegativeBinomial, Poisson, Tweedie,
    )

    design = np.asarray(X)
    response = np.asarray(y)
    n_obs = design.shape[0]
    fit_intercept = self._effective_intercept

    if fit_intercept:
        # Intercept goes in the FIRST column for this solver.
        design = np.column_stack([np.ones(n_obs), design])

    # IRLS works with the unnormalized X'WX, hence ridge strength n * alpha.
    # The intercept column is exempt from the ridge (matches sklearn/FISTA).
    ridge_alpha = float(n_obs * self.alpha)
    penalize_intercept = not fit_intercept

    # loss name -> GLM family class; anything unmatched falls back to Gaussian
    family_table = (
        ("logistic", Binomial),
        ("poisson", Poisson),
        ("gamma", Gamma),
        ("inverse_gaussian", InverseGaussian),
        ("negative_binomial", NegativeBinomial),
        ("tweedie", Tweedie),
    )
    family_cls = next(
        (cls for name, cls in family_table if self.loss == name),
        Gaussian,
    )
    family = family_cls()

    irls = IRLSSolver(family, max_iter=self.max_iter, tol=self.tol)
    params, n_iter = irls.fit(
        design, response, sample_weight=sample_weight,
        ridge_alpha=ridge_alpha,
        ridge_penalize_intercept=penalize_intercept,
        backend="numpy",
    )

    self.n_iter_ = n_iter

    if not fit_intercept:
        self.intercept_ = 0.0
        self.coef_ = params.copy()
        self._params = np.asarray(self.coef_).copy()
    else:
        self.intercept_ = float(params[0])
        self.coef_ = params[1:]
        self._params = np.concatenate([[self.intercept_], np.asarray(self.coef_)])

    n_terms = X.shape[1] + (1 if fit_intercept else 0)
    self._df_resid = self._nobs - n_terms
632
+
633
+
634
+
635
+ # --- _irls_cd (dead code, moved from _penalized.py) ---
636
def _irls_cd(self, pen, X_work, y_arr, init, _lla_continuation=False):
    """IRLS with coordinate descent for GLM + non-smooth penalties.

    The outer IRLS loop computes the working response ``z`` and working
    weights ``d``; the inner coordinate-descent (CD) loop solves the
    weighted penalized least-squares subproblem with per-coordinate
    thresholds (modelled on the R glmnet / ncvreg algorithms).
    Penalties: ``adaptive_l1`` / ``adaptive_lasso`` (threshold
    ``alpha * weight_j``), ``scad``, ``mcp``; any other penalty name falls
    back to a plain soft-threshold at ``alpha``.

    Parameters
    ----------
    pen : object
        Penalty (possibly a wrapper). Its ``name`` selects the thresholding
        rule; tuning parameters (``alpha``, ``_weights``, ``a``, ``gamma``)
        are read from ``self._penalty`` when that attribute exists,
        otherwise from ``pen``.
    X_work : ndarray, shape (n, pp)
        Design matrix. When ``self._effective_intercept`` is true the LAST
        column is the intercept, which is never penalized.
    y_arr : ndarray, shape (n,)
        Response vector.
    init : array-like of shape (pp,) or None
        Warm-start coefficients; zeros when None.
    _lla_continuation : bool, default False
        When True, skip the internal SCAD/MCP continuation path (the caller
        drives the path itself).

    Returns
    -------
    beta : ndarray, shape (pp,)
        Fitted coefficients (intercept last when fitted).
    n_iter : int
        Number of outer iterations run during the final continuation step.
    """
    import numpy as np

    n, pp = X_work.shape
    # Coordinates j >= p (the trailing intercept column) are unpenalized.
    p = pp - 1 if self._effective_intercept else pp

    # Access weights from the original penalty (not the SelectivePenalty wrapper)
    _inner = getattr(self, '_penalty', pen)
    _w = np.asarray(getattr(_inner, '_weights', np.ones(p)), dtype=float)
    # Adaptive-L1 threshold per coordinate is alpha * _w[j]; SCAD/MCP use
    # alpha directly. (Normalisation of _w, e.g. mean 1, is up to the
    # penalty object -- not enforced here.)
    alpha = float(getattr(_inner, 'alpha', self.alpha))
    pen_name = getattr(pen, 'name', '') or getattr(_inner, 'name', '')

    # SCAD/MCP parameters. Guard against division by zero in the
    # thresholding formulas below: SCAD divides by (a - 2) and (a - 1),
    # MCP by (1 - 1/gamma).
    a_scad = float(getattr(_inner, 'a', 3.7)) if pen_name == "scad" else 0.0
    if pen_name == "scad":
        a_scad = max(a_scad, 1.0 + 1e-6)
        if abs(a_scad - 2.0) < 1e-6:
            a_scad = 2.0 + 1e-6
    gamma_mcp = float(getattr(_inner, 'gamma', 3.0)) if pen_name == "mcp" else 0.0
    if pen_name == "mcp":
        gamma_mcp = max(gamma_mcp, 1.0 + 1e-6)

    if init is not None:
        beta = np.asarray(init, dtype=float).copy()
    else:
        beta = np.zeros(pp)

    loss_name = self._loss.name
    _is_glm = (loss_name != "squared_error")

    def _nonconvex_penalty_value(coef_slice, _pen_name, _alpha, _a_scad, _gamma_mcp):
        """Compute SCAD/MCP penalty value for a coefficient vector (0 otherwise)."""
        _abs_b = np.abs(coef_slice)
        if _pen_name == "scad":
            return float(np.sum(np.where(
                _abs_b <= _alpha, _alpha * _abs_b,
                np.where(_abs_b <= _a_scad * _alpha,
                         (_a_scad * _alpha * _abs_b - 0.5 * (coef_slice**2 + _alpha**2)) / (_a_scad - 1.0),
                         0.5 * (_a_scad + 1.0) * _alpha**2))))
        if _pen_name == "mcp":
            return float(np.sum(np.where(
                _abs_b <= _gamma_mcp * _alpha,
                _alpha * _abs_b - 0.5 * coef_slice**2 / _gamma_mcp,
                0.5 * _gamma_mcp * _alpha**2)))
        return 0.0

    # Continuation path for SCAD/MCP: trace the solution from lambda_max
    # down to the target alpha, matching R ncvreg's pathwise approach.
    # Without this, solving directly at the target alpha can converge to
    # a different local minimum (non-convex penalties have multiple local
    # minima that depend on the starting point).
    # Skipped when _lla_continuation=True (outer _fit_lla handles the path)
    # and when there are no penalized coordinates (p == 0).
    _cont_path = [alpha]
    if pen_name in ("scad", "mcp") and not _lla_continuation and p > 0:
        # lambda_max = max_j |X_j^T resid| / ||X_j||^2 at the null model.
        if loss_name == "logistic":
            _p0 = np.clip(np.mean(y_arr), 1e-3, 1 - 1e-3)
            _resid = y_arr - _p0
        elif loss_name == "poisson":
            _mu0 = max(float(np.mean(y_arr)), 1e-3)
            _resid = y_arr - _mu0
        elif loss_name == "gamma":
            _mu0 = max(float(np.mean(y_arr)), 1e-3)
            _resid = (y_arr - _mu0) / _mu0
        else:
            _resid = y_arr - np.mean(y_arr)
        _xty = np.abs(X_work[:, :p].T @ _resid)
        _xnorm_sq = np.sum(X_work[:, :p] ** 2, axis=0)
        _xnorm_sq = np.maximum(_xnorm_sq, 1e-20)
        _lam_max = float(np.max(_xty / _xnorm_sq))
        if _lam_max > alpha * 1.1:
            _n_cont = 100  # match ncvreg's default nlambda
            _cont_path = np.geomspace(_lam_max, alpha, _n_cont)

    # For GLM losses, do ONE CD sweep per IRLS iteration (matching
    # R ncvreg/glmnet); the IRLS outer loop handles convergence.
    # For squared_error, use the convergence-based CD loop since
    # there is no outer IRLS loop.
    _n_cd_sweeps_base = 1 if _is_glm else min(self.max_iter, 200)
    # For squared_error, the outer IRLS loop is redundant (d=1, z=y are
    # constant). Run the outer loop only once.
    _n_outer_base = self.max_iter if _is_glm else 1

    # For squared_error, d/z/XDX_diag are constant across continuation
    # steps -- compute once before the loop.
    if not _is_glm:
        d = np.ones(n)
        z = y_arr
        XDX_diag = np.sum(d[:, None] * X_work ** 2, axis=0)

    for _cont_idx, _cont_alpha in enumerate(_cont_path):
        # Update alpha for this continuation step
        if len(_cont_path) > 1:
            alpha = float(_cont_alpha)
            _is_last = (_cont_idx == len(_cont_path) - 1)
            _n_cd_sweeps = _n_cd_sweeps_base if _is_last else 20
            # For GLM with continuation: limit IRLS iterations on
            # non-final steps.
            if _is_glm:
                _n_outer = _n_outer_base if _is_last else min(20, _n_outer_base)
            else:
                _n_outer = _n_outer_base
        else:
            _n_cd_sweeps = _n_cd_sweeps_base
            _n_outer = _n_outer_base

        it = -1
        for it in range(_n_outer):
            beta_old = beta.copy()

            if _is_glm:
                eta = X_work @ beta
                if loss_name == "logistic":
                    mu = 1.0 / (1.0 + np.exp(-np.clip(eta, -500, 500)))
                    mu = np.clip(mu, 1e-15, 1.0 - 1e-15)
                    d = mu * (1.0 - mu)
                    z = eta + (y_arr - mu) / d
                elif loss_name == "poisson":
                    mu = np.exp(np.clip(eta, -500, 500))
                    mu = np.maximum(mu, 1e-15)
                    d = mu
                    z = eta + (y_arr - mu) / d
                elif loss_name == "gamma":
                    mu = np.exp(np.clip(eta, -500, 500))
                    mu = np.maximum(mu, 1e-15)
                    d = np.ones(n)
                    z = eta + (y_arr - mu) / mu
                elif loss_name == "inverse_gaussian":
                    # V(mu) = mu^3, log link g'(mu) = 1/mu
                    # IRLS weight: w = 1/(V(mu) * [g'(mu)]^2) = 1/mu
                    # Working response: z = eta + (y - mu) * g'(mu) = eta + (y - mu)/mu
                    mu = np.exp(np.clip(eta, -500, 500))
                    mu = np.maximum(mu, 1e-15)
                    d = 1.0 / mu
                    z = eta + (y_arr - mu) / mu
                elif loss_name == "negative_binomial":
                    # Treats self._loss.alpha as the NB size parameter theta,
                    # V(mu) = mu + mu^2/theta -- NOTE(review): confirm against
                    # the loss's own parameterisation.
                    # Log link: w = mu / (1 + mu/theta); the working response
                    # uses g'(mu) = 1/mu, NOT 1/w (dividing by w would make the
                    # fixed point the Poisson score equation).
                    mu = np.exp(np.clip(eta, -500, 500))
                    mu = np.maximum(mu, 1e-15)
                    theta_nb = float(getattr(self._loss, 'alpha', 1.0))
                    d = mu / (1.0 + mu / theta_nb)
                    z = eta + (y_arr - mu) / mu
                elif loss_name == "tweedie":
                    # V(mu) = mu^p, log link g'(mu) = 1/mu:
                    # w = 1/(V * g'^2) = mu^(2-p); z = eta + (y - mu)/mu.
                    # (p=1 reduces to the Poisson branch, p=2 to gamma.)
                    mu = np.exp(np.clip(eta, -500, 500))
                    mu = np.maximum(mu, 1e-15)
                    tweedie_p = float(getattr(self._loss, 'power', 1.5))
                    d = mu ** (2.0 - tweedie_p)
                    d = np.maximum(d, 1e-15)
                    z = eta + (y_arr - mu) / mu
                else:
                    grad = self._loss.gradient(X_work, y_arr, beta)
                    d = np.ones(n)
                    z = eta - grad * n
                XDX_diag = np.sum(d[:, None] * X_work ** 2, axis=0)

            # Normaliser of the CD subproblem: the sum of the working weights
            # (equals n for squared_error). Sample weights are not used here.
            n_eff = float(np.sum(d))

            r = z - X_work @ beta

            # Penalized objective before CD (for step-halving)
            if _is_glm:
                # Use full design matrix (including intercept) for correct objective
                _obj_before = float(self._loss.value(X_work, y_arr, beta))
                _obj_before += _nonconvex_penalty_value(beta[:p], pen_name, alpha, a_scad, gamma_mcp)

            for _cd in range(_n_cd_sweeps):
                _max_cd_change = 0.0
                for j in range(pp):
                    if XDX_diag[j] < 1e-20:
                        beta[j] = 0.0
                        continue

                    rho_j = np.dot(d * X_work[:, j], r) + XDX_diag[j] * beta[j]
                    old_bj = beta[j]

                    u_j = rho_j / n_eff
                    v_j = XDX_diag[j] / n_eff

                    # NOTE(review): thresholds below act on w_j = u_j / v_j, so
                    # the effective L1 level is alpha * v_j. That matches the
                    # step-halving objective (penalty at level alpha) only when
                    # v_j == 1, i.e. for columns standardised under weights d.
                    if j >= p:
                        # Intercept: unpenalized least-squares update.
                        beta[j] = u_j / v_j
                    elif pen_name in ("adaptive_l1", "adaptive_lasso"):
                        l1 = alpha * _w[j]
                        w_j = u_j / v_j
                        if w_j > l1:
                            beta[j] = (w_j - l1)
                        elif w_j < -l1:
                            beta[j] = (w_j + l1)
                        else:
                            beta[j] = 0.0
                    elif pen_name == "scad":
                        l1 = alpha
                        w_j = u_j / v_j
                        aw = np.abs(w_j)
                        if aw > a_scad * l1:
                            beta[j] = w_j
                        elif aw > l1:
                            beta[j] = np.sign(w_j) * ((a_scad - 1.0) * aw - a_scad * l1) / (a_scad - 2.0)
                        else:
                            beta[j] = 0.0
                    elif pen_name == "mcp":
                        l1 = alpha
                        w_j = u_j / v_j
                        aw = np.abs(w_j)
                        if aw > gamma_mcp * l1:
                            beta[j] = w_j
                        elif aw > l1:
                            beta[j] = np.sign(w_j) * (aw - l1) / (1.0 - 1.0 / gamma_mcp)
                        else:
                            beta[j] = 0.0
                    else:
                        l1 = alpha
                        w_j = u_j / v_j
                        if w_j > l1:
                            beta[j] = (w_j - l1)
                        elif w_j < -l1:
                            beta[j] = (w_j + l1)
                        else:
                            beta[j] = 0.0

                    if beta[j] != old_bj:
                        # Keep the working residual in sync with beta.
                        r += X_work[:, j] * (old_bj - beta[j])
                        _cd_change = abs(beta[j] - old_bj)
                        if _cd_change > _max_cd_change:
                            _max_cd_change = _cd_change

                # Inner CD convergence check (only for squared_error)
                if not _is_glm and _max_cd_change < self.tol:
                    break

            # Step-halving for GLM: ensure the penalized objective decreases
            # (prevents IRLS overshooting).
            if _is_glm:
                _obj_after = float(self._loss.value(X_work, y_arr, beta))
                _obj_after += _nonconvex_penalty_value(beta[:p], pen_name, alpha, a_scad, gamma_mcp)
                if _obj_after > _obj_before + 1e-10:
                    # beta_sh = beta_old + 0.5^k * (beta_new - beta_old)
                    beta_new = beta.copy()
                    for _sh in range(1, 11):
                        _frac = 0.5 ** _sh
                        beta[:] = beta_old + _frac * (beta_new - beta_old)
                        _obj_after = float(self._loss.value(X_work, y_arr, beta))
                        _obj_after += _nonconvex_penalty_value(beta[:p], pen_name, alpha, a_scad, gamma_mcp)
                        if _obj_after <= _obj_before + 1e-10:
                            break

            # IRLS-level convergence check (no penalized coordinates -> 0).
            _delta = float(np.max(np.abs(beta[:p] - beta_old[:p]))) if p > 0 else 0.0
            if not _is_glm and _delta < self.tol:
                break
            # For GLM with continuation: early exit on convergence for
            # non-final steps (avoids wasting iterations).
            if _is_glm and len(_cont_path) > 1 and not _is_last:
                if _delta < self.tol * 10:
                    break

    n_iter = (it + 1) if _n_outer > 0 else 0
    return beta, n_iter
907
+
908
+
909
+ # --- _fit_lla (dead code, moved from _penalized.py) ---
910
def _fit_lla(self, X, y, sample_weight, backend_name, init_coef=None):
    """Fit a non-convex penalty via Local Linear Approximation (LLA).

    The outer loop re-expresses the non-convex penalty as a per-coordinate
    weighted L1 penalty (weights from ``penalty.lla_weights(coef)``). Each
    inner iteration solves that convex problem with the currently selected
    solver, warm-started from the previous LLA estimate.

    A continuation path is used for all losses. Alpha is stepped from
    ``lambda_max`` down to the target ``self.alpha``:

    * ``lambda_max`` is the smallest penalty at which all coefficients are
      zero. It is computed on column-scaled X for logistic, poisson, gamma
      and squared_error; other losses use ``15 * alpha``.
    * The path has 20 steps for SCAD/MCP and 10 otherwise. It is geometric
      when both ends are positive and linear otherwise.

    Starting large forces coefficients to cross the SCAD/MCP transition
    region (alpha .. a*alpha). This is the pathwise strategy of R's ncvreg.

    Two execution paths:

    * squared_error + SCAD/MCP: the whole continuation + LLA + FISTA loop
      runs inside ``statgpu.solvers.fista_lla_path``.
    * otherwise: each continuation step runs up to ``_max_lla_per_step``
      LLA iterations. Each iteration temporarily swaps ``self._penalty``
      for ``AdaptiveL1Penalty(alpha=1.0)`` with ``_weights = lla_weights``
      and calls ``_fit_torch`` / ``_fit_gpu`` / ``_fit_cpu``.

    For SCAD/MCP the starting coefficients are reset to zeros regardless
    of ``init_coef`` (ncvreg starts at lambda_max from the null model).

    Side effects: sets ``coef_``, ``intercept_`` and ``_lla_n_iters_``.
    The fused path also sets ``n_iter_``. ``cpu_solver``,
    ``_selected_solver``, ``max_iter`` and ``_penalty`` are restored to
    their entry values, even on error.

    Parameters
    ----------
    X, y : array-like
        Design matrix (features only) and response.
    sample_weight : array-like or None
        Forwarded to the inner fits.
    backend_name : str
        ``"torch"``, ``"cupy"``, or anything else (treated as CPU).
    init_coef : array-like or None
        Optional starting coefficients (ignored for SCAD/MCP, see above).
    """
    # Imported once at function scope: a local ``import copy`` further down
    # would make ``copy`` local to the whole function and raise
    # UnboundLocalError at its first use.
    import copy

    from statgpu.penalties._adaptive_l1 import AdaptiveL1Penalty

    n_features = X.shape[1]

    if init_coef is not None:
        coef_lla = np.asarray(init_coef, dtype=float).copy()
    elif self._penalty.requires_init:
        coef_lla = np.zeros(n_features)
    else:
        coef_lla = self._fit_initial(X, y, backend_name=backend_name)

    _pen_name_init = str(getattr(self._penalty, 'name', '')).lower()
    _is_scad_mcp = _pen_name_init in ("scad", "mcp")
    _is_glm = (self.loss != "squared_error")
    _is_glm_scad_mcp = _is_glm and _is_scad_mcp
    # SCAD/MCP: start from all-zero coefficients (as ncvreg does at
    # lambda_max); a large L2-penalized init can overflow the exp-link
    # working response when eta is extreme.
    if _is_scad_mcp:
        coef_lla = np.zeros(n_features)

    # Losses whose BB-step FISTA is considered safe on the continuation path.
    _glm_bb_safe = _is_glm and self.loss in ("poisson", "logistic")

    saved_cpu_solver = self.cpu_solver
    saved_selected_solver = self._selected_solver
    saved_max_iter = self.max_iter

    # Everything that mutates solver / max_iter state lives inside this try
    # so the finally clause always restores it.
    try:
        # Inner-solver selection:
        # * GLM other than poisson/logistic: "fista" (backtracking); BB
        #   step estimates are unreliable for e.g. gamma's 1/mu gradient.
        # * squared_error + SCAD/MCP or on the numpy backend: "fista_bb"
        #   (precomputes XtX, cheap per-iteration gradient on CPU).
        # * squared_error on other backends: "admm" (fast cuBLAS matmuls).
        # * poisson/logistic: "fista_bb", set per continuation step below.
        if _is_glm and not _glm_bb_safe:
            _solver = "fista"
        elif not _is_glm and (_is_scad_mcp or backend_name == "numpy"):
            _solver = "fista_bb"
        elif not _is_glm:
            _solver = "admm"
        else:
            _solver = None
        if _solver is not None:
            self.cpu_solver = _solver
            self._selected_solver = _solver

        # lambda_max, matching R ncvreg: max_j |sum(x_s_j * resid)| / n on
        # X scaled so that ||X_j|| = sqrt(n) (i.e. mean(x^2) = 1).
        _X_np = np.asarray(X, dtype=float)
        _y_np = np.asarray(y, dtype=float)
        _n = _X_np.shape[0]
        _col_norms = np.maximum(np.sqrt(np.sum(_X_np ** 2, axis=0)), 1e-20)
        _X_s = _X_np * (np.sqrt(_n) / _col_norms)
        if self.loss == "logistic":
            _p0 = np.clip(np.mean(_y_np), 1e-3, 1 - 1e-3)
            _resid = _y_np - _p0
        elif self.loss == "poisson":
            _mu0 = max(float(np.mean(_y_np)), 1e-3)
            _resid = _y_np - _mu0
        elif self.loss == "gamma":
            _mu0 = max(float(np.mean(_y_np)), 1e-3)
            _resid = (_y_np - _mu0) / _mu0
        elif self.loss == "squared_error":
            _resid = _y_np - np.mean(_y_np)
        else:
            _resid = None
        if _resid is None:
            _lam_max = self.alpha * 15.0  # fallback for other losses
        else:
            _lam_max = float(np.max(np.abs(_X_s.T @ _resid / _n)))

        _n_cont = 20 if _is_scad_mcp else 10
        _alpha_start = float(_lam_max)
        _alpha_end = float(self.alpha)
        if _alpha_start <= 0.0 or _alpha_end <= 0.0:
            # geomspace needs positive endpoints; fall back to a linear path
            # that still ends at the (clamped) target alpha.
            _lo = max(min(_alpha_start, _alpha_end), 1e-12)
            _hi = max(_alpha_start, _alpha_end, 1e-12)
            if _hi <= _lo:
                _alpha_path = np.full(_n_cont, _hi, dtype=float)
            else:
                _alpha_path = np.linspace(_hi, _lo, _n_cont, dtype=float)
            _alpha_path[-1] = max(_alpha_end, 1e-12)
        else:
            _alpha_path = np.geomspace(_alpha_start, _alpha_end, _n_cont)
        _max_lla_per_step = max(6, self._max_lla_iters // _n_cont)

        if _is_scad_mcp and not _is_glm:
            # squared_error + SCAD/MCP: fused LLA+FISTA path, running the
            # entire continuation+LLA+FISTA loop in one function to avoid
            # per-call overhead.
            from statgpu.solvers import fista_lla_path
            X_cached = self._to_array(X, backend=backend_name)
            y_cached = self._to_array(y, backend=backend_name)

            # max_iter schedule: early steps need fewer iterations.
            _mi_path = [
                saved_max_iter if _i == _n_cont - 1 else max(100, saved_max_iter // 10)
                for _i in range(_n_cont)
            ]

            coef_np, intercept, n_iter = fista_lla_path(
                self._loss, self._penalty,
                X_cached, y_cached,
                alpha_path=_alpha_path,
                max_lla_per_step=_max_lla_per_step,
                lla_tol=self._lla_tol,
                max_iter=_mi_path,
                tol=self.tol,
                fit_intercept=self._effective_intercept,
                sample_weight=sample_weight,
            )
            coef_lla = coef_np
            self.coef_ = coef_np
            self.intercept_ = intercept
            self.n_iter_ = n_iter
            self._lla_n_iters_ = _n_cont * _max_lla_per_step
        else:
            # Convert arrays to the backend once, outside the continuation loop.
            X_cached = self._to_array(X, backend=backend_name)
            y_cached = self._to_array(y, backend=backend_name)
            # Count LLA iterations for this fit only.
            self._lla_n_iters_ = 0

            for _cont_step, _cont_alpha in enumerate(_alpha_path):
                # Copy the penalty with the continuation alpha so the
                # user's penalty object is never mutated.
                _pen_step = copy.copy(self._penalty)
                _pen_step.alpha = float(_cont_alpha)

                _is_last_cont = (_cont_step == _n_cont - 1)
                if _is_glm_scad_mcp:
                    self.max_iter = 500 if _is_last_cont else 100
                elif _is_last_cont:
                    self.max_iter = saved_max_iter
                else:
                    self.max_iter = max(200, saved_max_iter // 3)
                if self.loss == "gamma":
                    self.max_iter = max(300, self.max_iter // 2)
                if _glm_bb_safe:
                    # NOTE(review): an earlier design note said the final step
                    # should switch to backtracking "fista"; the code uses
                    # "fista_bb" on every step, including the last.
                    self.cpu_solver = "fista_bb"
                    self._selected_solver = "fista_bb"

                for _lla_local in range(_max_lla_per_step):
                    # lla_weights() returns alpha-scaled derivative weights,
                    # so AdaptiveL1Penalty(alpha=1) with _weights=lla_w
                    # yields exactly sum_j lla_w_j * |coef_j|. Weights have
                    # one entry per feature; the intercept is handled by the
                    # inner fit.
                    lla_w = _pen_step.lla_weights(coef_lla)
                    inner_pen = AdaptiveL1Penalty(alpha=1.0)
                    inner_pen._weights = lla_w

                    # Swap in the LLA penalty; restore the ORIGINAL object
                    # (not a copy) and clear the warm start even on error.
                    orig_penalty = self._penalty
                    self._penalty = inner_pen
                    try:
                        self._init_coef = coef_lla.copy()
                        if backend_name == "torch":
                            self._fit_torch(X_cached, y_cached, sample_weight)
                        elif backend_name == "cupy":
                            self._fit_gpu(X_cached, y_cached, sample_weight)
                        else:
                            self._fit_cpu(X_cached, y_cached, sample_weight)
                    finally:
                        self._init_coef = None
                        self._penalty = orig_penalty

                    # LLA convergence: L1 change in the coefficients.
                    coef_new = self.coef_.copy()
                    delta = float(np.sum(np.abs(coef_new - coef_lla)))
                    self._lla_n_iters_ += 1
                    coef_lla = coef_new
                    if delta < self._lla_tol:
                        break

        # Fallback in case no branch above left coef_ set: derive
        # coef_/intercept_ from coef_lla (intercept from (weighted) means).
        if self.coef_ is None and coef_lla is not None:
            self.coef_ = np.asarray(coef_lla[:X.shape[1]], dtype=float)
            if self._effective_intercept:
                X_np = np.asarray(X, dtype=float)
                y_np = np.asarray(y, dtype=float)
                if sample_weight is not None:
                    sw_np = np.asarray(sample_weight, dtype=float).ravel()
                    sw_sum = max(float(np.sum(sw_np)), 1e-15)
                    X_wmean = np.sum(X_np * sw_np[:, None], axis=0) / sw_sum
                    y_wmean = float(np.sum(y_np * sw_np)) / sw_sum
                    self.intercept_ = float(y_wmean - X_wmean @ self.coef_)
                else:
                    self.intercept_ = float(np.mean(y_np) - np.mean(X_np, axis=0) @ self.coef_)
            else:
                self.intercept_ = 0.0
            self._params = np.concatenate([[self.intercept_], self.coef_])
            self._df_resid = X.shape[0] - (X.shape[1] + (1 if self._effective_intercept else 0))
    finally:
        self.cpu_solver = saved_cpu_solver
        self._selected_solver = saved_selected_solver
        self.max_iter = saved_max_iter
1174
+