statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,605 @@
1
+ """
2
+ Group SCAD penalty.
3
+
4
+ Breheny & Huang 2009 (grpreg). Non-convex group penalty: applies SCAD
5
+ concavity to the L2 norm of each feature group.
6
+
7
+ Penalty:
8
+ P(w) = sum_g SCAD(||w_g||_2; alpha * sqrt(p_g), a)
9
+
10
+ where SCAD(t; lambda, a) is the element-wise SCAD penalty.
11
+ """
12
+
13
+ __all__ = ["GroupSCADPenalty"]
14
+
15
+ from typing import Optional, List, Union
16
+ import numpy as np
17
+ from statgpu.penalties._base import Penalty
18
+ from statgpu.penalties._group_lasso import _vector_norm, _to_backend_array, _backend_zeros, _batched_group_norms, _get_xp
19
+
20
# ---- torch.compile lazy-loader for vectorized SCAD proximal ---------
# Cached compiled callable; stays ``None`` until built or when unavailable.
_GROUP_SCAD_PROXIMAL_TORCH_COMPILED = None
# Set once building the compiled callable has raised, so the failing
# import/compile setup is not retried on every proximal step.
_GROUP_SCAD_PROXIMAL_TORCH_FAILED = False


def _get_group_scad_torch_compiled():
    """Return a ``torch.compile``-d equal-size group-SCAD proximal kernel.

    The kernel takes ``(w_mat, sqrt_pg, alpha, step, a)`` where ``w_mat`` has
    one row per group. It returns the row-scaled matrix flattened to 1D, in
    row-major group order.

    Returns
    -------
    callable or None
        The compiled kernel, or ``None`` when ``_torch_compile_ok()`` reports
        that compilation should not be used, or when building the kernel
        raised. A build failure is remembered and not retried.
    """
    global _GROUP_SCAD_PROXIMAL_TORCH_COMPILED, _GROUP_SCAD_PROXIMAL_TORCH_FAILED
    if _GROUP_SCAD_PROXIMAL_TORCH_COMPILED is not None:
        return _GROUP_SCAD_PROXIMAL_TORCH_COMPILED
    if _GROUP_SCAD_PROXIMAL_TORCH_FAILED:
        return None
    from statgpu.penalties import _torch_compile_ok
    # Not cached: the availability check is re-evaluated on each call, as before.
    if not _torch_compile_ok():
        return None
    try:
        import torch

        def _prox(w_mat, sqrt_pg, alpha, step, a):
            alpha_g = alpha * sqrt_pg
            t_g = alpha_g * step
            a_alpha_g = a * alpha_g
            norms = torch.linalg.norm(w_mat, dim=1)
            # Region 1: group soft-threshold by t_g.
            mask_r1 = norms <= alpha_g + t_g
            safe_norms = torch.where(norms > 0.0, norms, torch.ones_like(norms))
            scale_r1 = torch.clamp((norms - t_g) / safe_norms, min=0.0)
            # Region 2: SCAD interpolation; the denominator is only used where mask_r2 holds.
            mask_r2 = (norms > alpha_g + t_g) & (norms <= a_alpha_g)
            denom = norms * (a - 1.0 - step)
            denom = torch.where(mask_r2, denom, torch.ones_like(denom))
            scale_r2 = ((a - 1.0) * norms - a * t_g) / denom
            # Region 3 (norm > a * alpha_g): identity, scale 1.
            scale = torch.where(mask_r1, scale_r1, 1.0)
            scale = torch.where(mask_r2, scale_r2, scale)
            return (w_mat * scale[:, None]).reshape(-1)

        _GROUP_SCAD_PROXIMAL_TORCH_COMPILED = torch.compile(
            _prox, dynamic=True, mode='reduce-overhead'
        )
    except Exception:
        # Best-effort optimisation: callers fall back to the eager path.
        _GROUP_SCAD_PROXIMAL_TORCH_FAILED = True
        _GROUP_SCAD_PROXIMAL_TORCH_COMPILED = None
    return _GROUP_SCAD_PROXIMAL_TORCH_COMPILED
55
+
56
+
57
class GroupSCADPenalty(Penalty):
    """Group SCAD penalty.

    Applies the SCAD penalty to the L2 norm of each feature group, using the
    group-size-adjusted threshold ``alpha_g = alpha * sqrt(p_g)``::

        P(w) = sum_g SCAD(||w_g||_2; alpha_g, a)

    Parameters
    ----------
    alpha : float, default=1.0
        Regularization strength. Must be finite and positive.
    a : float, default=3.7
        SCAD concavity parameter. Must be finite and > 2.
    groups : list of lists, or 1D array-like, optional
        Group membership specification. Either a sequence of per-group
        feature-index sequences, or a flat sequence of integer group ids
        (one per feature). Must be set before ``value``, ``gradient``,
        ``proximal`` or ``lla_weights`` is called.

    Notes
    -----
    Group SCAD is **non-convex** (``is_convex=False``), optimized via LLA
    (Local Linear Approximation). The objective function may contain multiple
    local minima. Different solvers or different initializations can converge
    to different local minima with comparable objective values — a coefficient
    ``max|diff|`` up to ~1e-2 across runs is expected and does not indicate a
    bug.

    Coefficient vectors may carry trailing entries beyond the grouped
    features, for example an augmented intercept. Those entries are not
    penalized: they contribute nothing to ``value``, get a zero gradient and
    are copied unchanged by ``proximal``.
    """

    name = "group_scad"
    is_convex = False
    supports_group = True

    def __init__(
        self,
        alpha: float = 1.0,
        a: float = 3.7,
        groups=None,
    ):
        if not np.isfinite(alpha) or alpha <= 0.0:
            raise ValueError("alpha must be a finite positive scalar for group SCAD penalty")
        if not np.isfinite(a) or a <= 2.0:
            raise ValueError("a must be a finite scalar greater than 2 for group SCAD penalty")
        self.alpha = alpha
        self.a = a
        self._group_indices = None
        self._group_sizes = None
        self._sqrt_pg = None
        self._n_groups = 0
        self._all_equal_size = False
        self._is_contiguous = False
        self._group_size_uniform = None
        self._flat_indices = None
        self._group_feat_idx = None
        # attr_name -> ((backend, device, dtype), backend array); see _get_cached.
        self._device_cache = {}

        if groups is not None:
            self._init_groups(groups)

    # ----------------------------------------------------------------
    # Group bookkeeping
    # ----------------------------------------------------------------

    @staticmethod
    def _indices_from_ids(ids):
        """Convert a flat per-feature group-id vector into per-group index arrays."""
        group_ids = np.asarray(ids, dtype=int)
        n_groups = int(group_ids.max() + 1)
        return [np.where(group_ids == g)[0] for g in range(n_groups)]

    def _init_groups(self, groups):
        """Parse a group specification and precompute the index tables.

        Parameters
        ----------
        groups : 1D np.ndarray, list or tuple
            A 1D ``np.ndarray`` or flat list/tuple of integer group ids, one
            per feature; group ``g`` is every feature whose id equals ``g``
            for ``g`` in ``0..max_id``. Or a list/tuple of per-group
            feature-index sequences.

        Raises
        ------
        ValueError
            If ``groups`` is empty, contains duplicate feature indices, or
            does not cover exactly the dense index range ``[0, max_index]``.
        TypeError
            If ``groups`` is neither a list/tuple nor a 1D ``np.ndarray``.
        """
        if isinstance(groups, np.ndarray) and groups.ndim == 1:
            self._group_indices = self._indices_from_ids(groups)
        elif isinstance(groups, (list, tuple)):
            if len(groups) == 0:
                raise ValueError("groups must not be empty")
            if isinstance(groups[0], (list, tuple, np.ndarray)):
                self._group_indices = [np.asarray(g, dtype=int) for g in groups]
            else:
                self._group_indices = self._indices_from_ids(groups)
        else:
            raise TypeError(
                f"groups must be list or array, got {type(groups).__name__}"
            )

        sizes = np.array([len(g) for g in self._group_indices], dtype=int)
        self._group_sizes = sizes
        self._sqrt_pg = np.sqrt(sizes.astype(float))
        self._n_groups = len(self._group_indices)
        if self._n_groups == 0:
            raise ValueError("groups must contain at least one feature index")

        # Feature indices in group order; validated once, reused below.
        flat_indices = np.concatenate(
            [np.asarray(g, dtype=np.int64) for g in self._group_indices]
        )
        if flat_indices.size == 0:
            raise ValueError("groups must contain at least one feature index")
        unique_idx = np.unique(flat_indices)
        if unique_idx.size != flat_indices.size:
            raise ValueError("groups contain duplicate feature indices")
        n_features = int(flat_indices.max()) + 1
        # unique_idx is sorted: requiring it to start at 0 rejects negative
        # indices, which could otherwise pass the count check (e.g. {-1, 0, 2}).
        if unique_idx[0] != 0 or unique_idx.size != n_features:
            raise ValueError(
                "groups must cover a dense range of feature indices [0..max_index]"
            )

        # Groups are contiguous iff concatenating them in order yields 0..p-1.
        self._is_contiguous = bool(
            np.array_equal(flat_indices, np.arange(flat_indices.size))
        )
        self._flat_indices = None if self._is_contiguous else flat_indices

        self._all_equal_size = bool(np.all(sizes == sizes[0]))
        self._group_size_uniform = int(sizes[0]) if self._all_equal_size else None

        # Padded gather/scatter tables for unequal group sizes: entry k of the
        # concatenated group order lives at padded[row_idx[k], col_idx[k]].
        if not self._all_equal_size:
            self._padded_row_idx = np.repeat(np.arange(self._n_groups), sizes).astype(np.int64)
            self._padded_col_idx = np.concatenate([np.arange(sz) for sz in sizes]).astype(np.int64)

        # Feature -> group map (used by gradient / lla_weights).
        group_feat_idx = np.empty(n_features, dtype=np.int64)
        group_feat_idx[flat_indices] = np.repeat(np.arange(self._n_groups), sizes)
        self._group_feat_idx = group_feat_idx

        # Any device copies made from the previous tables are now stale.
        self._device_cache = {}

    def _get_cached(self, attr_name, xp, w):
        """Return ``getattr(self, attr_name)`` converted to ``w``'s backend.

        Conversions are memoised per attribute. The memo key is the backend
        module name plus the device and dtype of ``w``. A penalty reused
        across backends, devices or precisions therefore never returns an
        array built for a different one. ``_init_groups`` clears the memo.
        """
        key = (
            xp.__name__,
            str(getattr(w, "device", "cpu")),
            str(getattr(w, "dtype", "")),
        )
        cache = getattr(self, "_device_cache", None)
        if cache is None:
            cache = {}
            self._device_cache = cache
        entry = cache.get(attr_name)
        if entry is not None and entry[0] == key:
            return entry[1]
        arr = _to_backend_array(getattr(self, attr_name), xp, w)
        cache[attr_name] = (key, arr)
        return arr

    def _get_sqrt_pg(self, xp, w):
        """Per-group ``sqrt(group size)`` on ``w``'s backend/device."""
        return self._get_cached('_sqrt_pg', xp, w)

    def _get_flat_indices(self, xp, w):
        """Concatenated group feature indices on ``w``'s backend, or ``None`` if contiguous."""
        if getattr(self, '_flat_indices', None) is None:
            return None
        return self._get_cached('_flat_indices', xp, w)

    def _batched_group_norms_vec(self, coef_feat, xp, w_ref):
        """Group L2 norms for unequal group sizes via a zero-padded ``(G, max_size)`` matrix."""
        G = self._n_groups
        max_sz = int(self._group_sizes.max())
        padded = _backend_zeros((G, max_sz), xp, dtype=coef_feat.dtype, ref_arr=w_ref)
        row_idx_dev = self._get_cached('_padded_row_idx', xp, w_ref)
        col_idx_dev = self._get_cached('_padded_col_idx', xp, w_ref)
        if self._is_contiguous:
            padded[row_idx_dev, col_idx_dev] = coef_feat
        else:
            padded[row_idx_dev, col_idx_dev] = coef_feat[self._get_flat_indices(xp, w_ref)]
        # Zero padding does not change an L2 norm.
        return _vector_norm(padded, xp, dim=1)

    def _reshape_to_matrix(self, w, xp, G, gs):
        """Equal-size groups: gather the grouped features of ``w`` into a ``(G, gs)`` matrix."""
        w_feat = w[:G * gs]  # drop trailing non-grouped entries (augmented intercept)
        if self._is_contiguous:
            return w_feat.reshape(G, gs)
        return w_feat[self._get_flat_indices(xp, w)].reshape(G, gs)

    def _group_norms(self, coef_feat, xp, w_ref):
        """Return the L2 norm of every group (length ``G``) of ``coef_feat``."""
        if self._all_equal_size and self._group_size_uniform is not None:
            w_mat = self._reshape_to_matrix(coef_feat, xp, self._n_groups, self._group_size_uniform)
            return _vector_norm(w_mat, xp, dim=1)
        return self._batched_group_norms_vec(coef_feat, xp, w_ref)

    def _scatter_from_flat(self, flat_vals, result, xp):
        """Write group-ordered ``flat_vals`` back into feature positions of ``result`` in place."""
        p_total = len(flat_vals)
        if self._is_contiguous:
            result[:p_total] = flat_vals
        else:
            flat_idx = self._get_flat_indices(xp, result)
            result[flat_idx] = flat_vals

    # ----------------------------------------------------------------
    # Value
    # ----------------------------------------------------------------

    def value(self, coef) -> float:
        """Return the penalty value as a Python float.

        Per group with norm ``n`` and threshold ``l = alpha_g``:

        - ``n <= l``: ``l * n``
        - ``l < n <= a*l``: ``(2*a*l*n - n**2 - l**2) / (2*(a-1))``
        - ``n > a*l``: ``(a+1) * l**2 / 2``

        Raises
        ------
        ValueError
            If groups have not been set.
        """
        if self._group_indices is None:
            raise ValueError("groups must be set before calling value()")

        xp = _get_xp(coef)
        a = self.a
        p_total = int(self._group_sizes.sum())
        norms = self._group_norms(coef[:p_total], xp, coef)

        alpha_g = self.alpha * self._get_sqrt_pg(xp, coef)
        a_alpha_g = a * alpha_g

        mask_r1 = norms <= alpha_g
        mask_r2 = (norms > alpha_g) & (norms <= a_alpha_g)
        mask_r3 = norms > a_alpha_g

        # Region 1: linear.
        total = xp.sum(alpha_g[mask_r1] * norms[mask_r1])
        # Region 2: quadratic SCAD interpolation.
        ng2 = norms[mask_r2]
        ag2 = alpha_g[mask_r2]
        total += xp.sum(-(ng2 * ng2 - 2.0 * a * ag2 * ng2 + ag2 * ag2) / (2.0 * (a - 1.0)))
        # Region 3: constant.
        total += xp.sum((a + 1.0) * alpha_g[mask_r3] ** 2 / 2.0)
        return float(total)

    # ----------------------------------------------------------------
    # Gradient
    # ----------------------------------------------------------------

    def gradient(self, coef) -> np.ndarray:
        """Return the penalty gradient with the same shape as ``coef``.

        For a group with norm ``n > 0`` the gradient is ``d(n) * w_g / n``.
        ``d`` is ``alpha_g`` for ``n <= alpha_g``, ``(a*alpha_g - n)/(a-1)``
        up to ``a*alpha_g``, and ``0`` beyond. Zero-norm groups and trailing
        non-grouped entries get zero.

        Raises
        ------
        ValueError
            If groups have not been set.
        """
        if self._group_indices is None:
            raise ValueError("groups must be set before calling gradient()")

        xp = _get_xp(coef)
        a = self.a
        p_total = int(self._group_sizes.sum())
        coef_feat = coef[:p_total]  # handle augmented intercept

        norms = self._group_norms(coef_feat, xp, coef)
        alpha_g = self.alpha * self._get_sqrt_pg(xp, coef)
        a_alpha_g = a * alpha_g

        # One fused per-group scale d(n)/n; the 1e-15 floor keeps zero-norm
        # groups finite (their coefficients are 0, so their gradient is 0).
        if xp.__name__ == "torch":
            safe_norms = xp.clamp(norms, min=1e-15)
        else:
            safe_norms = xp.maximum(norms, 1e-15)
        scale_g = xp.where(
            norms <= alpha_g,
            alpha_g / safe_norms,
            xp.where(
                (norms > alpha_g) & (norms <= a_alpha_g),
                (a * alpha_g - norms) / ((a - 1.0) * safe_norms),
                0.0,
            ),
        )

        feat_idx = self._get_cached('_group_feat_idx', xp, coef)
        grad = xp.zeros_like(coef)
        grad[:p_total] = scale_g[feat_idx] * coef_feat
        return grad

    # ----------------------------------------------------------------
    # Proximal operator (group SCAD)
    # ----------------------------------------------------------------

    def proximal(self, w, step: float, backend: str = "numpy"):
        """Apply the group-SCAD proximal operator with step size ``step``.

        ``backend`` selects the implementation. ``"cupy"`` and ``"torch"`` use
        the vectorized path; anything else uses the per-group NumPy loop.
        Returns a new array; ``w`` is not modified.

        Raises
        ------
        ValueError
            If groups have not been set.
        """
        if self._group_indices is None:
            raise ValueError("groups must be set before calling proximal()")

        if backend == "cupy":
            import cupy as cp
            return self._proximal_vectorized(w, step, cp)
        if backend == "torch":
            import torch
            return self._proximal_vectorized(w, step, torch)
        return self._proximal_loop(w, step, np)

    def _proximal_loop(self, w, step, xp):
        """Per-group SCAD proximal via a Python loop (NumPy path)."""
        # Defense-in-depth: keep (a - 1 - step) > 0, the region-2 denominator.
        step = min(float(step), 0.9 * (self.a - 1.0))
        result = w.copy() if hasattr(w, 'copy') else w.clone()
        for g, idx in enumerate(self._group_indices):
            w_g = w[idx]
            ng = float(xp.linalg.norm(w_g))
            alpha_g = self.alpha * self._sqrt_pg[g]
            t_g = alpha_g * step
            a_alpha_g = self.a * alpha_g

            if ng <= alpha_g + t_g:
                # Region 1: group soft-threshold by t_g.
                if ng > 0:
                    result[idx] = w_g * max(0.0, ng - t_g) / ng
                else:
                    result[idx] = 0.0
            elif ng <= a_alpha_g:
                # Region 2: SCAD interpolation between thresholding and identity.
                if ng > 1e-15:
                    scale = ((self.a - 1.0) * ng - self.a * t_g) / (ng * (self.a - 1.0 - step))
                    result[idx] = w_g * scale
                else:
                    result[idx] = 0.0
            else:
                # Region 3: no shrinkage.
                result[idx] = w_g
        return result

    def _proximal_vectorized(self, w, step, xp):
        """Vectorized group SCAD proximal; dispatches on equal vs unequal group sizes."""
        G = self._n_groups
        if self._all_equal_size and self._group_size_uniform is not None:
            return self._proximal_equal(w, step, xp, G, self._group_size_uniform)
        max_sz = int(self._group_sizes.max())
        return self._proximal_padded(w, step, xp, G, max_sz)

    def _proximal_equal(self, w, step, xp, G, gs):
        """Equal-size groups: vectorized 3-region SCAD proximal on a ``(G, gs)`` matrix."""
        # Clamp step to prevent division by zero in denom = norms*(a-1-step)
        step = min(float(step), 0.9 * (self.a - 1.0))
        w_mat = self._reshape_to_matrix(w, xp, G, gs)
        sqrt_pg_arr = self._get_sqrt_pg(xp, w)

        # Torch compiled fast path (skipped permanently for this instance once
        # the compiled kernel has failed).
        if xp.__name__ == "torch" and not getattr(self, "_torch_compile_failed", False):
            compiled_fn = _get_group_scad_torch_compiled()
            if compiled_fn is not None:
                try:
                    scaled_flat = compiled_fn(w_mat, sqrt_pg_arr, self.alpha, step, self.a)
                except Exception:
                    # torch.compile compiles lazily on the first call, so backend
                    # failures surface here; fall back to the eager path below.
                    self._torch_compile_failed = True
                else:
                    result = w.clone()
                    self._scatter_from_flat(scaled_flat, result, xp)
                    return result

        # Generic vectorized path
        norms = _vector_norm(w_mat, xp, dim=1)
        alpha_g = self.alpha * sqrt_pg_arr  # (G,)
        t_g = alpha_g * step                # (G,)
        a_alpha_g = self.a * alpha_g        # (G,)

        # Region 1: soft-threshold, norm <= alpha_g + t_g
        mask_r1 = norms <= alpha_g + t_g
        safe_norms = xp.where(norms > 0.0, norms, xp.ones_like(norms))
        scale_r1 = xp.clip((norms - t_g) / safe_norms, 0.0, None)

        # Region 2: SCAD intermediate, alpha_g + t_g < norm <= a * alpha_g
        mask_r2 = (norms > alpha_g + t_g) & (norms <= a_alpha_g)
        denom = norms * (self.a - 1.0 - step)
        denom = xp.where(mask_r2, denom, xp.ones_like(denom))
        scale_r2 = ((self.a - 1.0) * norms - self.a * t_g) / denom

        # Region 3: no shrinkage (scale 1.0 default)
        scale = xp.where(mask_r1, scale_r1, 1.0)
        scale = xp.where(mask_r2, scale_r2, scale)

        scaled_flat = (w_mat * scale[:, None]).reshape(-1)
        result = w.copy() if hasattr(w, 'copy') else w.clone()
        self._scatter_from_flat(scaled_flat, result, xp)
        return result

    def _proximal_padded(self, w, step, xp, G, max_sz):
        """Unequal groups: pad to ``(G, max_sz)``, apply the 3-region proximal, unpack."""
        step = min(float(step), 0.9 * (self.a - 1.0))
        p_total = int(self._group_sizes.sum())
        w_feat = w[:p_total]  # handle augmented intercept

        # Build padded matrix via fancy indexing
        padded = _backend_zeros((G, max_sz), xp, dtype=w.dtype, ref_arr=w)
        row_idx_dev = self._get_cached('_padded_row_idx', xp, w)
        col_idx_dev = self._get_cached('_padded_col_idx', xp, w)
        flat_idx_dev = None
        if self._is_contiguous:
            padded[row_idx_dev, col_idx_dev] = w_feat
        else:
            flat_idx_dev = self._get_flat_indices(xp, w)
            padded[row_idx_dev, col_idx_dev] = w_feat[flat_idx_dev]

        norms = _vector_norm(padded, xp, dim=1)
        alpha_g = self.alpha * self._get_sqrt_pg(xp, w)
        t_g = alpha_g * step
        a_alpha_g = self.a * alpha_g

        mask_r1 = norms <= alpha_g + t_g
        safe_norms = xp.where(norms > 0.0, norms, xp.ones_like(norms))
        scale_r1 = xp.clip((norms - t_g) / safe_norms, 0.0, None)

        mask_r2 = (norms > alpha_g + t_g) & (norms <= a_alpha_g)
        denom = norms * (self.a - 1.0 - step)
        denom = xp.where(mask_r2, denom, xp.ones_like(denom))
        scale_r2 = ((self.a - 1.0) * norms - self.a * t_g) / denom

        scale = xp.where(mask_r1, scale_r1, 1.0)
        scale = xp.where(mask_r2, scale_r2, scale)

        padded_scaled = padded * scale[:, None]

        # Scatter back via fancy indexing; padding slots are never read.
        result = w.copy() if hasattr(w, 'copy') else w.clone()
        if self._is_contiguous:
            result[:p_total] = padded_scaled[row_idx_dev, col_idx_dev]
        else:
            result[flat_idx_dev] = padded_scaled[row_idx_dev, col_idx_dev]
        return result

    # ----------------------------------------------------------------
    # LLA weights (for LLA outer loop optimization)
    # ----------------------------------------------------------------

    def lla_weights(self, coef):
        """Return per-feature LLA weights (SCAD derivative at each group norm).

        Group ``g`` with norm ``n`` gets ``alpha_g`` if ``n <= alpha_g``,
        ``(a*alpha_g - n)/(a-1)`` if ``n <= a*alpha_g``, else ``0``. The value
        is broadcast to every feature of the group via the feature->group map.
        This is correct for contiguous and non-contiguous layouts alike. The
        result has length ``sum(group sizes)``; trailing non-grouped entries
        of ``coef`` are not included.

        Raises
        ------
        ValueError
            If groups have not been set.
        """
        if self._group_indices is None:
            raise ValueError("groups must be set before calling lla_weights()")

        xp = _get_xp(coef)
        a = self.a
        p_total = int(self._group_sizes.sum())
        norms = self._group_norms(coef[:p_total], xp, coef)

        alpha_g = self.alpha * self._get_sqrt_pg(xp, coef)
        a_alpha_g = a * alpha_g

        weight_g = xp.where(
            norms <= alpha_g,
            alpha_g,
            xp.where(
                norms <= a_alpha_g,
                (a * alpha_g - norms) / (a - 1.0),
                xp.zeros_like(norms),
            ),
        )

        # Index by feature->group map rather than repeat(): repeat() assumes
        # groups occupy consecutive features, which fails for non-contiguous
        # equal-size groups.
        if xp.__name__ == "numpy":
            feat_idx = self._group_feat_idx
        else:
            feat_idx = self._get_cached('_group_feat_idx', xp, coef)
        return weight_g[feat_idx]

    # ----------------------------------------------------------------

    def get_params(self) -> dict:
        """Return base-class params plus ``alpha``, ``a`` and ``n_groups``."""
        params = super().get_params()
        params.update({
            "alpha": self.alpha,
            "a": self.a,
            "n_groups": self._n_groups,
        })
        return params