statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,553 @@
1
"""
Group MCP penalty.

Breheny & Huang 2009 (grpreg). Non-convex group penalty: applies MCP
concavity to the L2 norm of each feature group.

Penalty:
    P(w) = sum_g MCP(||w_g||_2; alpha * sqrt(p_g), gamma)

where p_g is the number of features in group g and MCP(t; lambda, gamma) is
the scalar MCP penalty evaluated at the group norm t = ||w_g||_2:

    MCP(t; lambda, gamma) = lambda * t - t**2 / (2 * gamma)   if t <= gamma * lambda
                          = gamma * lambda**2 / 2             otherwise

Coefficients beyond the last grouped feature (e.g. an augmented intercept
appended to the coefficient vector) are not penalized.
"""

__all__ = ["GroupMCPPenalty"]
14
+
15
+ from typing import Optional, List, Union
16
+ import numpy as np
17
+ from statgpu.penalties._base import Penalty
18
+ from statgpu.penalties._group_lasso import _vector_norm, _to_backend_array, _backend_zeros, _batched_group_norms, _get_xp
19
+
20
# ---- torch.compile lazy-loader for vectorized MCP proximal ---------
# Compiled kernel once built successfully (None until then).
_GROUP_MCP_PROXIMAL_TORCH_COMPILED = None
# Set once importing torch or wrapping with torch.compile has raised, so that
# later calls return None immediately instead of repeating the failing attempt.
_GROUP_MCP_PROXIMAL_TORCH_COMPILE_FAILED = False


def _get_group_mcp_torch_compiled():
    """Return a ``torch.compile``'d equal-size group-MCP proximal kernel, or None.

    The kernel has signature ``(w_mat, sqrt_pg, alpha, step, gamma)`` where
    ``w_mat`` is a ``(G, gs)`` matrix of grouped coefficients. It returns the
    flattened (row-major) proximal result.

    Returns ``None`` when ``statgpu.penalties._torch_compile_ok()`` reports that
    compilation should not be used. That check is repeated on every call so a
    later change in its answer is honoured. ``None`` is also returned when
    importing torch or calling ``torch.compile`` failed earlier; that failure
    is remembered and not retried.
    """
    global _GROUP_MCP_PROXIMAL_TORCH_COMPILED, _GROUP_MCP_PROXIMAL_TORCH_COMPILE_FAILED
    if _GROUP_MCP_PROXIMAL_TORCH_COMPILED is not None:
        return _GROUP_MCP_PROXIMAL_TORCH_COMPILED
    if _GROUP_MCP_PROXIMAL_TORCH_COMPILE_FAILED:
        return None

    from statgpu.penalties import _torch_compile_ok
    if not _torch_compile_ok():
        return None

    try:
        import torch

        def _prox(w_mat, sqrt_pg, alpha, step, gamma):
            # Per-group thresholds: zero below t_g, MCP shrink up to gamma*lambda_g.
            t_g = alpha * sqrt_pg * step
            gamma_alpha_g = gamma * alpha * sqrt_pg
            norms = torch.linalg.norm(w_mat, dim=1)
            mask_zero = norms <= t_g
            mask_shrink = (norms > t_g) & (norms <= gamma_alpha_g)
            denom = norms * (1.0 - step / gamma)
            # Avoid dividing by (possibly zero) norms outside the shrink region.
            denom = torch.where(mask_shrink, denom, torch.ones_like(denom))
            scale_shrink = (norms - t_g) / denom
            scale = torch.where(mask_shrink, scale_shrink, 1.0)
            scale = torch.where(mask_zero, 0.0, scale)
            return (w_mat * scale[:, None]).reshape(-1)

        _GROUP_MCP_PROXIMAL_TORCH_COMPILED = torch.compile(
            _prox, dynamic=True, mode='reduce-overhead'
        )
    except Exception:
        # Best effort: callers fall back to the eager implementation.
        _GROUP_MCP_PROXIMAL_TORCH_COMPILE_FAILED = True
        _GROUP_MCP_PROXIMAL_TORCH_COMPILED = None
    return _GROUP_MCP_PROXIMAL_TORCH_COMPILED
52
+
53
+
54
class GroupMCPPenalty(Penalty):
    """Group MCP penalty.

    Each group ``g`` of ``p_g`` features is penalized through the L2 norm of
    its coefficients::

        P(w) = sum_g MCP(||w_g||_2; lambda_g, gamma),   lambda_g = alpha * sqrt(p_g)

        MCP(t; lam, gamma) = lam * t - t**2 / (2 * gamma)   if t <= gamma * lam
                           = gamma * lam**2 / 2             otherwise

    Parameters
    ----------
    alpha : float, default=1.0
        Regularization strength. Must be finite and > 0.
    gamma : float, default=3.0
        MCP concavity parameter. Must be finite and > 1. Larger gamma makes the
        penalty closer to the group lasso: shrinkage persists up to the larger
        group norm ``gamma * lambda_g``, which means more bias. Gamma close to 1
        approaches hard group thresholding.
    groups : list of lists, or 1D array-like, optional
        Group membership specification. Either a list of per-group feature
        index lists, or a per-feature vector of non-negative integer group
        ids. Together the groups must cover ``0 .. max_index`` exactly once.
        Groups must be set before calling ``value``, ``gradient``,
        ``proximal`` or ``lla_weights``.

    Raises
    ------
    ValueError
        If ``alpha``/``gamma`` is out of range or ``groups`` is invalid.
    TypeError
        If ``groups`` is neither a list/tuple nor a 1D ``np.ndarray``.

    Notes
    -----
    Group MCP is **non-convex** (``is_convex=False``), optimized via LLA
    (Local Linear Approximation). The objective function may contain multiple
    local minima. Different solvers or different initializations can converge
    to different local minima with comparable objective values — a coefficient
    ``max|diff|`` up to ~1e-2 across runs is expected and does not indicate a
    bug.

    Coefficient vectors may be longer than the number of grouped features
    (e.g. an augmented intercept). Entries past the last grouped feature are
    never penalized, and the methods below leave them untouched.
    """

    name = "group_mcp"
    is_convex = False
    supports_group = True

    def __init__(
        self,
        alpha: float = 1.0,
        gamma: float = 3.0,
        groups=None,
    ):
        if not np.isfinite(alpha) or alpha <= 0.0:
            raise ValueError("alpha must be a finite positive scalar for group MCP penalty")
        if not np.isfinite(gamma) or gamma <= 1.0:
            raise ValueError("gamma must be a finite scalar greater than 1 for group MCP penalty")
        self.alpha = alpha
        self.gamma = gamma
        # Set when a torch.compile'd proximal fails at call time; use eager after that.
        self._torch_compile_failed = False
        self._reset_group_state()

        if groups is not None:
            self._init_groups(groups)

    # ----------------------------------------------------------------
    # Group bookkeeping
    # ----------------------------------------------------------------

    def _reset_group_state(self):
        """Clear every attribute derived from a group specification."""
        self._group_indices = None
        self._group_sizes = None
        self._sqrt_pg = None
        self._n_groups = 0
        self._p_total = 0
        self._all_equal_size = False
        self._is_contiguous = False
        self._group_size_uniform = None
        self._flat_indices = None
        self._padded_row_idx = None
        self._padded_col_idx = None
        self._group_feat_idx = None
        # (attr_name, backend) -> backend copy of a host numpy array.
        self._device_cache = {}

    @staticmethod
    def _indices_from_group_ids(group_ids):
        """Convert a per-feature vector of integer group ids into per-group index arrays.

        Group ``g`` receives the ascending feature positions whose id equals
        ``g``. Ids in ``0..max`` that never occur yield empty groups.
        """
        ids = np.asarray(group_ids, dtype=int)
        if ids.size == 0:
            raise ValueError("groups must not be empty")
        if ids.min() < 0:
            raise ValueError("group ids must be non-negative integers")
        n_groups = int(ids.max()) + 1
        # One stable sort replaces a full scan of ``ids`` per group.
        order = np.argsort(ids, kind="stable")
        counts = np.bincount(ids, minlength=n_groups)
        return np.split(order, np.cumsum(counts)[:-1])

    def _init_groups(self, groups):
        """Parse, validate and store a group specification.

        Precomputes the host-side index arrays used by the vectorized paths.
        State is only committed after validation succeeds, so an invalid
        specification never leaves the penalty half-initialized. All cached
        backend copies are discarded.

        Raises
        ------
        TypeError
            If ``groups`` is neither a list/tuple nor a 1D ``np.ndarray``.
        ValueError
            If ``groups`` is empty, has negative ids/indices, duplicate
            feature indices, or does not cover ``0..max_index`` densely.
        """
        if isinstance(groups, np.ndarray) and groups.ndim == 1:
            group_indices = self._indices_from_group_ids(groups)
        elif isinstance(groups, (list, tuple)):
            if len(groups) == 0:
                raise ValueError("groups must not be empty")
            if isinstance(groups[0], (list, tuple, np.ndarray)):
                group_indices = [np.asarray(g, dtype=int) for g in groups]
            else:
                group_indices = self._indices_from_group_ids(groups)
        else:
            raise TypeError(
                f"groups must be list or array, got {type(groups).__name__}"
            )

        group_sizes = np.array([len(g) for g in group_indices], dtype=int)
        # Feature indices listed group by group ("group-major" order).
        flat_indices = np.concatenate(
            [np.asarray(g, dtype=np.int64) for g in group_indices]
        )
        if flat_indices.size == 0:
            raise ValueError("groups must contain at least one feature index")
        if flat_indices.min() < 0:
            raise ValueError("groups contain negative feature indices")
        n_features = int(flat_indices.max()) + 1
        unique_idx = np.unique(flat_indices)
        if unique_idx.size != flat_indices.size:
            raise ValueError("groups contain duplicate feature indices")
        if unique_idx.size != n_features:
            raise ValueError(
                "groups must cover a dense range of feature indices [0..max_index]"
            )

        n_groups = len(group_indices)
        # Group (row) and in-group column of each group-major position, which
        # drive the padded gather/scatter used for unequal group sizes.
        position_group = np.repeat(np.arange(n_groups, dtype=np.int64), group_sizes)
        position_col = np.concatenate(
            [np.arange(sz, dtype=np.int64) for sz in group_sizes]
        )
        # Feature j belongs to group feature_to_group[j].
        feature_to_group = np.empty(n_features, dtype=np.int64)
        feature_to_group[flat_indices] = position_group

        # Contiguous: groups, in order, tile 0..p-1, so a reshape maps
        # coefficients to groups without any gather.
        is_contiguous = bool(np.array_equal(flat_indices, np.arange(flat_indices.size)))
        all_equal = bool(np.unique(group_sizes).size == 1)

        self._group_indices = group_indices
        self._group_sizes = group_sizes
        self._sqrt_pg = np.sqrt(group_sizes.astype(float))
        self._n_groups = n_groups
        self._p_total = int(flat_indices.size)
        self._all_equal_size = all_equal
        self._group_size_uniform = int(group_sizes[0]) if all_equal else None
        self._is_contiguous = is_contiguous
        self._flat_indices = None if is_contiguous else flat_indices
        self._padded_row_idx = None if all_equal else position_group
        self._padded_col_idx = None if all_equal else position_col
        self._group_feat_idx = feature_to_group
        # Backend copies of the arrays above are now stale.
        self._device_cache = {}

    # ----------------------------------------------------------------
    # Backend array caching
    # ----------------------------------------------------------------

    @staticmethod
    def _backend_name(xp):
        """Cache key for the array module ``xp``: 'torch', 'cupy' or 'numpy'."""
        name = xp.__name__
        return name if name in ("torch", "cupy") else "numpy"

    def _get_cached(self, attr_name, xp, w):
        """Return ``getattr(self, attr_name)`` converted for ``xp``, memoized.

        Copies are keyed by ``(attr_name, backend)``, so numpy, cupy and torch
        never share an entry. The cache is cleared whenever groups are
        re-initialized. A cached torch tensor is rebuilt when ``w`` lives on a
        different device.
        """
        cache = getattr(self, "_device_cache", None)
        if cache is None:
            cache = self._device_cache = {}
        backend = self._backend_name(xp)
        key = (attr_name, backend)
        cached = cache.get(key)
        if cached is not None and backend == "torch" and cached.device != w.device:
            cached = None
        if cached is None:
            cached = _to_backend_array(getattr(self, attr_name), xp, w)
            cache[key] = cached
        return cached

    def _get_sqrt_pg(self, xp, w):
        """Cached backend copy of ``sqrt(p_g)`` (one entry per group)."""
        return self._get_cached("_sqrt_pg", xp, w)

    def _get_flat_indices(self, xp, w):
        """Cached backend copy of the group-major feature order (None if contiguous)."""
        if getattr(self, "_flat_indices", None) is None:
            return None
        return self._get_cached("_flat_indices", xp, w)

    # ----------------------------------------------------------------
    # Layout helpers
    # ----------------------------------------------------------------

    def _to_padded(self, coef_feat, xp, ref, G, max_sz):
        """Gather grouped coefficients into a zero-padded ``(G, max_sz)`` matrix."""
        padded = _backend_zeros((G, max_sz), xp, dtype=coef_feat.dtype, ref_arr=ref)
        row_idx_dev = self._get_cached("_padded_row_idx", xp, ref)
        col_idx_dev = self._get_cached("_padded_col_idx", xp, ref)
        if self._is_contiguous:
            padded[row_idx_dev, col_idx_dev] = coef_feat
        else:
            padded[row_idx_dev, col_idx_dev] = coef_feat[self._get_flat_indices(xp, ref)]
        return padded

    def _batched_group_norms_vec(self, coef_feat, xp, w_ref):
        """Per-group L2 norms for unequal group sizes via a padded matrix."""
        max_sz = int(self._group_sizes.max())
        padded = self._to_padded(coef_feat, xp, w_ref, self._n_groups, max_sz)
        return _vector_norm(padded, xp, dim=1)

    def _reshape_to_matrix(self, w, xp, G, gs):
        """View the first ``G * gs`` entries of ``w`` as a ``(G, gs)`` group matrix."""
        w_feat = w[:G * gs]  # drop trailing entries (e.g. augmented intercept)
        if self._is_contiguous:
            return w_feat.reshape(G, gs)
        return w_feat[self._get_flat_indices(xp, w)].reshape(G, gs)

    def _scatter_from_flat(self, flat_vals, result, xp):
        """Write group-major ``flat_vals`` back into ``result`` in feature order."""
        p_total = len(flat_vals)
        if self._is_contiguous:
            result[:p_total] = flat_vals
        else:
            result[self._get_flat_indices(xp, result)] = flat_vals

    def _group_norms(self, coef_feat, xp, ref):
        """Per-group L2 norms (shape ``(G,)``) of the grouped coefficients."""
        if self._all_equal_size and self._group_size_uniform is not None:
            w_mat = self._reshape_to_matrix(
                coef_feat, xp, self._n_groups, self._group_size_uniform
            )
            return _vector_norm(w_mat, xp, dim=1)
        return self._batched_group_norms_vec(coef_feat, xp, ref)

    # ----------------------------------------------------------------
    # Value
    # ----------------------------------------------------------------

    def value(self, coef) -> float:
        """Evaluate the group MCP penalty at ``coef``.

        Entries of ``coef`` past the grouped features are ignored.

        Raises
        ------
        ValueError
            If groups have not been set.
        """
        if self._group_indices is None:
            raise ValueError("groups must be set before calling value()")

        xp = _get_xp(coef)
        coef_feat = coef[:self._p_total]  # handle augmented intercept
        norms = self._group_norms(coef_feat, xp, coef)

        alpha_g = self.alpha * self._get_sqrt_pg(xp, coef)
        gamma_alpha_g = self.gamma * alpha_g
        # Concave region below gamma*lambda_g, constant penalty above it.
        per_group = xp.where(
            norms <= gamma_alpha_g,
            alpha_g * norms - norms ** 2 / (2.0 * self.gamma),
            0.5 * self.gamma * alpha_g ** 2,
        )
        return float(xp.sum(per_group))

    # ----------------------------------------------------------------
    # Gradient
    # ----------------------------------------------------------------

    def gradient(self, coef) -> np.ndarray:
        """Gradient of the penalty, same shape as ``coef``.

        Groups with zero norm or norm above ``gamma * lambda_g`` get a zero
        gradient. Entries past the grouped features are zero.

        Raises
        ------
        ValueError
            If groups have not been set.
        """
        if self._group_indices is None:
            raise ValueError("groups must be set before calling gradient()")

        xp = _get_xp(coef)
        is_torch = xp.__name__ == "torch"
        p_total = self._p_total
        coef_feat = coef[:p_total]  # handle augmented intercept
        norms = self._group_norms(coef_feat, xp, coef)

        alpha_g = self.alpha * self._get_sqrt_pg(xp, coef)
        gamma_alpha_g = self.gamma * alpha_g

        # d/dw_g MCP(||w_g||) = (lambda_g - ||w_g||/gamma) * w_g / ||w_g|| on the
        # active region; fused into one per-group scale.
        mask_active = (norms > 0) & (norms <= gamma_alpha_g)
        safe_norms = xp.clamp(norms, min=1e-15) if is_torch else xp.maximum(norms, 1e-15)
        scale_g = xp.where(mask_active,
                           (alpha_g - norms / self.gamma) / safe_norms,
                           0.0)

        feat_idx = self._get_cached("_group_feat_idx", xp, coef)
        grad = xp.zeros_like(coef)
        grad[:p_total] = scale_g[feat_idx] * coef_feat
        return grad

    # ----------------------------------------------------------------
    # Proximal operator (group MCP)
    # ----------------------------------------------------------------

    def proximal(self, w, step: float, backend: str = "numpy"):
        """Per-group MCP proximal map of ``w`` with step size ``step``.

        ``backend`` selects the implementation: ``"cupy"`` and ``"torch"`` use
        the vectorized path, anything else uses a per-group numpy loop.
        ``step`` is clamped to ``0.9 * gamma`` to keep the shrinkage
        denominator ``1 - step/gamma`` positive. Returns a new array.

        Raises
        ------
        ValueError
            If groups have not been set.
        """
        if self._group_indices is None:
            raise ValueError("groups must be set before calling proximal()")

        if backend == "cupy":
            import cupy as cp
            return self._proximal_vectorized(w, step, cp)
        if backend == "torch":
            import torch
            return self._proximal_vectorized(w, step, torch)
        return self._proximal_loop(w, step, np)

    def _proximal_loop(self, w, step, xp):
        """Reference per-group implementation (used for the numpy backend)."""
        step = min(float(step), 0.9 * self.gamma)  # defense-in-depth clamping
        result = w.copy() if hasattr(w, 'copy') else w.clone()
        for g, idx in enumerate(self._group_indices):
            w_g = w[idx]
            ng = float(xp.linalg.norm(w_g))
            t_g = self.alpha * self._sqrt_pg[g] * step
            gamma_alpha_g = self.gamma * self.alpha * self._sqrt_pg[g]

            if ng <= t_g:
                result[idx] = 0.0
            elif t_g < ng <= gamma_alpha_g:
                scale = (ng - t_g) / (ng * (1.0 - step / self.gamma))
                result[idx] = w_g * scale
            else:
                result[idx] = w_g
        return result

    def _mcp_group_scale(self, norms, sqrt_pg_arr, step, xp):
        """Per-group multiplicative factor of the group-MCP proximal map.

        Three regions: ``norm <= t_g`` gives 0; ``t_g < norm <= gamma*lambda_g``
        gives ``(norm - t_g) / (norm * (1 - step/gamma))``; otherwise 1
        (identity). Here ``t_g = lambda_g * step``.
        """
        t_g = self.alpha * sqrt_pg_arr * step
        gamma_alpha_g = self.gamma * self.alpha * sqrt_pg_arr
        mask_zero = norms <= t_g
        mask_shrink = (norms > t_g) & (norms <= gamma_alpha_g)
        denom = norms * (1.0 - step / self.gamma)
        # Only divide where shrinking; elsewhere norms may be zero.
        denom = xp.where(mask_shrink, denom, xp.ones_like(denom))
        scale = xp.where(mask_shrink, (norms - t_g) / denom, 1.0)
        return xp.where(mask_zero, 0.0, scale)

    def _proximal_vectorized(self, w, step, xp):
        """Vectorized group MCP proximal."""
        G = self._n_groups

        if self._all_equal_size and self._group_size_uniform is not None:
            gs = self._group_size_uniform
            return self._proximal_equal(w, step, xp, G, gs)

        max_sz = int(self._group_sizes.max())
        return self._proximal_padded(w, step, xp, G, max_sz)

    def _proximal_equal(self, w, step, xp, G, gs):
        """Equal-size groups: vectorized MCP proximal on a ``(G, gs)`` matrix."""
        # Clamp step to prevent division by zero in denom = norms*(1 - step/gamma)
        step = min(float(step), 0.9 * self.gamma)
        w_mat = self._reshape_to_matrix(w, xp, G, gs)
        sqrt_pg_arr = self._get_sqrt_pg(xp, w)

        scaled_flat = None
        # Torch compiled fast path. torch.compile compiles lazily on the first
        # call, so failures surface here; fall back to eager instead of crashing.
        if xp.__name__ == "torch" and not getattr(self, "_torch_compile_failed", False):
            compiled_fn = _get_group_mcp_torch_compiled()
            if compiled_fn is not None:
                try:
                    scaled_flat = compiled_fn(w_mat, sqrt_pg_arr, self.alpha, step, self.gamma)
                except Exception as exc:
                    import warnings
                    warnings.warn(
                        "torch.compile'd group MCP proximal failed "
                        f"({type(exc).__name__}: {exc}); falling back to eager mode.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    self._torch_compile_failed = True
                    scaled_flat = None

        # Generic vectorized path
        if scaled_flat is None:
            norms = _vector_norm(w_mat, xp, dim=1)
            scale = self._mcp_group_scale(norms, sqrt_pg_arr, step, xp)
            scaled_flat = (w_mat * scale[:, None]).reshape(-1)

        result = w.copy() if hasattr(w, 'copy') else w.clone()
        self._scatter_from_flat(scaled_flat, result, xp)
        return result

    def _proximal_padded(self, w, step, xp, G, max_sz):
        """Unequal groups: pad, apply the MCP proximal row-wise, unpack."""
        step = min(float(step), 0.9 * self.gamma)
        p_total = self._p_total
        padded = self._to_padded(w[:p_total], xp, w, G, max_sz)

        norms = _vector_norm(padded, xp, dim=1)
        scale = self._mcp_group_scale(norms, self._get_sqrt_pg(xp, w), step, xp)
        padded_scaled = padded * scale[:, None]

        # Scatter back via fancy indexing — padding cells are never read.
        row_idx_dev = self._get_cached("_padded_row_idx", xp, w)
        col_idx_dev = self._get_cached("_padded_col_idx", xp, w)
        result = w.copy() if hasattr(w, 'copy') else w.clone()
        if self._is_contiguous:
            result[:p_total] = padded_scaled[row_idx_dev, col_idx_dev]
        else:
            result[self._get_flat_indices(xp, w)] = padded_scaled[row_idx_dev, col_idx_dev]
        return result

    # ----------------------------------------------------------------
    # LLA weights (for LLA outer loop optimization)
    # ----------------------------------------------------------------

    def lla_weights(self, coef):
        """Per-feature LLA weights ``max(lambda_g - ||w_g||/gamma, 0)``.

        Groups with norm above ``gamma * lambda_g`` get weight 0. The result
        has one entry per grouped feature, in feature order.

        Raises
        ------
        ValueError
            If groups have not been set.
        """
        if self._group_indices is None:
            raise ValueError("groups must be set before calling lla_weights()")

        xp = _get_xp(coef)
        is_torch = xp.__name__ == "torch"
        coef_feat = coef[:self._p_total]  # handle augmented intercept
        norms = self._group_norms(coef_feat, xp, coef)

        alpha_g = self.alpha * self._get_sqrt_pg(xp, coef)
        gamma_alpha_g = self.gamma * alpha_g

        raw = alpha_g - norms / self.gamma
        raw = xp.clamp(raw, min=0.0) if is_torch else xp.maximum(raw, 0.0)
        weight_g = xp.where(norms <= gamma_alpha_g, raw, 0.0)

        # Broadcast through the feature->group map rather than repeating each
        # group's weight gs times: repeating lists weights in group order, which
        # mislabels features when equal-size groups are not contiguous.
        feat_idx = self._get_cached("_group_feat_idx", xp, coef)
        return weight_g[feat_idx]

    # ----------------------------------------------------------------

    def get_params(self) -> dict:
        """Return base-class params plus ``alpha``, ``gamma`` and ``n_groups``."""
        params = super().get_params()
        params.update({
            "alpha": self.alpha,
            "gamma": self.gamma,
            "n_groups": self._n_groups,
        })
        return params