statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,678 @@
1
+ """
2
+ Group Lasso penalty.
3
+
4
+ Yuan & Lin, JRSSB 2006. Convex penalty that selects groups of features.
5
+
6
+ The penalty is:
7
+ P(w) = alpha * sum_g sqrt(p_g) * ||w_g||_2
8
+
9
+ where w_g is the subvector of w for group g, and p_g is the size of group g.
10
+ """
11
+
12
+ __all__ = ["GroupLassoPenalty", "AdaptiveGroupLassoPenalty"]
13
+
14
+ from typing import Optional, List, Union
15
+ import numpy as np
16
+ from statgpu.penalties._base import Penalty
17
+
18
+ # ---- torch.compile lazy-loader for vectorized proximal on GPU ---------
19
+ _GROUP_LASSO_PROXIMAL_TORCH_COMPILED_EQUAL = None
20
+
21
+
22
def _get_group_lasso_torch_compiled_equal():
    """Return the lazily built ``torch.compile`` kernel for equal-size groups.

    The compiled callable takes ``(w_mat, sqrt_pg, alpha, step)`` where
    ``w_mat`` is the ``(G, gs)`` coefficient matrix, applies block
    soft-thresholding row by row and returns the flattened result.

    Returns ``None`` when ``_torch_compile_ok()`` reports that compilation is
    not usable, or when building the compiled function raises.  A successful
    build is memoised in the module-level global; a failure is not, so the
    next call tries again.
    """
    global _GROUP_LASSO_PROXIMAL_TORCH_COMPILED_EQUAL

    already_built = _GROUP_LASSO_PROXIMAL_TORCH_COMPILED_EQUAL
    if already_built is not None:
        return already_built

    # Imported lazily, at call time rather than at module import.
    from statgpu.penalties import _torch_compile_ok

    if not _torch_compile_ok():
        _GROUP_LASSO_PROXIMAL_TORCH_COMPILED_EQUAL = None
        return None

    try:
        import torch

        def _prox(w_mat, sqrt_pg, alpha, step):
            # Per-group threshold, then shrink every row towards zero.
            thresh = alpha * sqrt_pg * step
            norms = torch.linalg.norm(w_mat, dim=1)
            scale = torch.clamp(1.0 - thresh / (norms + 1e-12), min=0.0)
            return (w_mat * scale[:, None]).reshape(-1)

        compiled = torch.compile(_prox, dynamic=True, mode='reduce-overhead')
    except Exception:
        # Best effort: callers fall back to the eager vectorised path.
        compiled = None

    _GROUP_LASSO_PROXIMAL_TORCH_COMPILED_EQUAL = compiled
    return compiled
44
+
45
+
46
def _vector_norm(x, xp, dim=None):
    """Return the L2 norm of ``x`` using the array module ``xp``.

    With ``dim=None`` the norm of the whole array is returned; otherwise the
    norm is taken along ``dim``.  The only backend difference handled here is
    the keyword name: torch calls it ``dim``, other modules call it ``axis``.
    """
    if dim is None:
        return xp.linalg.norm(x)
    axis_keyword = "dim" if xp.__name__ == "torch" else "axis"
    return xp.linalg.norm(x, **{axis_keyword: dim})
51
+
52
+
53
def _to_backend_array(arr, xp, ref_arr=None):
    """Convert host data ``arr`` into an array of the backend ``xp``.

    For torch, integer inputs (signed or unsigned) keep their dtype so they
    can be used as indices, and everything else becomes ``float64``.  When
    ``ref_arr`` is given the tensor is moved to ``ref_arr.device``.  For any
    other backend the conversion is simply ``xp.asarray(arr)``.
    """
    if xp.__name__ != "torch":
        return xp.asarray(arr)

    import torch

    host = np.asarray(arr)
    if host.dtype.kind not in ('i', 'u'):
        # Non-integer data is always promoted to double precision.
        host = host.astype(np.float64)
    tensor = torch.from_numpy(host)
    if ref_arr is not None:
        tensor = tensor.to(device=ref_arr.device)
    return tensor
67
+
68
+
69
def _backend_zeros(shape, xp, dtype=None, ref_arr=None):
    """Create a zero-filled array of ``shape`` on the backend ``xp``.

    Parameters
    ----------
    shape : int or tuple of int
        Shape of the array to create.
    xp : module
        Array module (``numpy``, ``cupy`` or ``torch``), dispatched on
        ``xp.__name__``.
    dtype : optional
        Backend dtype.  For torch, ``None`` means ``torch.float64``; for other
        backends it is forwarded unchanged to ``xp.zeros``.
    ref_arr : optional
        Torch only: the result is created on ``ref_arr.device``.

    Returns
    -------
    Array of zeros belonging to ``xp``.
    """
    if xp.__name__ == "torch":
        import torch

        # Allocate directly on the target device instead of creating the
        # tensor on the default device and copying it over with ``.to()``.
        return torch.zeros(
            shape,
            dtype=torch.float64 if dtype is None else dtype,
            device=None if ref_arr is None else ref_arr.device,
        )
    return xp.zeros(shape, dtype=dtype)
78
+
79
+
80
def _batched_group_norms(coef, group_indices, xp):
    """Return the L2 norm of every group of ``coef`` as a ``(G,)`` array.

    ``group_indices`` is a sequence of index collections, one per group.  An
    empty group contributes a zero: a zero scalar with ``coef``'s dtype (and,
    for torch, its device) on torch/cupy, and the Python float ``0.0``
    otherwise.  The per-group values are combined with ``xp.stack`` (torch),
    ``xp.array`` (cupy) or ``np.array`` (everything else).
    """
    backend = xp.__name__

    def _zero_norm():
        # Norm used for a group with no members.
        if backend == "torch":
            return xp.zeros(1, device=coef.device, dtype=coef.dtype)[0]
        if backend == "cupy":
            return xp.zeros(1, dtype=coef.dtype)[0]
        return 0.0

    per_group = [
        _vector_norm(coef[members], xp) if len(members) > 0 else _zero_norm()
        for members in group_indices
    ]

    if backend == "torch":
        return xp.stack(per_group)
    if backend == "cupy":
        return xp.array(per_group)
    return np.array(per_group)
98
+
99
+
100
+ # Use canonical _xp from backends (replaces local _get_xp)
101
+ from statgpu.backends._array_ops import _xp as _get_xp
102
+
103
+
104
class GroupLassoPenalty(Penalty):
    """Group Lasso penalty.

    The penalty value is ``alpha * sum_g sqrt(p_g) * ||w_g||_2`` where ``w_g``
    is the sub-vector of coefficients belonging to group ``g`` and ``p_g`` is
    the size of that group.

    Parameters
    ----------
    alpha : float, default=1.0
        Regularization strength.
    groups : list of lists, or 1D array-like
        Group membership specification. Two forms accepted:
        - List of lists of feature indices, e.g. [[0,1], [2,3,4]]
        - 1D array of length n_features where each entry is the group ID

    Notes
    -----
    Coefficient vectors may be longer than the number of grouped features
    (e.g. an appended intercept); only the first ``sum(p_g)`` entries are
    penalised and the remaining entries are passed through unchanged.

    Host-side layout arrays (``_sqrt_pg``, ``_padded_row_idx``, ...) are
    mirrored onto torch/cupy lazily and cached in ``<attr>_torch`` /
    ``<attr>_cupy`` attributes.
    """

    name = "group_lasso"
    is_convex = True
    supports_group = True

    # Host attributes that have a lazily created device mirror.
    _DEVICE_MIRRORED = (
        "_sqrt_pg",
        "_padded_row_idx",
        "_padded_col_idx",
        "_flat_indices",
        "_group_feat_idx",
    )

    def __init__(
        self,
        alpha: float = 1.0,
        groups=None,
    ):
        self.alpha = alpha
        self._group_indices = None
        self._group_sizes = None
        self._sqrt_pg = None
        self._n_groups = 0
        self._all_equal_size = False
        self._is_contiguous = False
        self._group_size_uniform = None
        self._flat_indices = None
        self._padded_row_idx = None
        self._padded_col_idx = None
        self._group_feat_idx = None
        self._invalidate_device_caches()

        if groups is not None:
            self._init_groups(groups)

    # ----------------------------------------------------------------
    # Group layout
    # ----------------------------------------------------------------

    def _invalidate_device_caches(self):
        """Drop every cached torch/cupy mirror of the host layout arrays."""
        for attr_name in self._DEVICE_MIRRORED:
            for backend in ("torch", "cupy"):
                setattr(self, f"{attr_name}_{backend}", None)

    @staticmethod
    def _parse_groups(groups):
        """Turn a group specification into a list of integer index arrays.

        Raises
        ------
        ValueError
            If ``groups`` is empty.
        TypeError
            If ``groups`` is neither a list/tuple nor a 1-D ndarray.
        """

        def _from_group_ids(ids):
            group_ids = np.asarray(ids, dtype=int)
            if group_ids.size == 0:
                raise ValueError("groups must not be empty")
            n_groups = int(group_ids.max() + 1)
            # A stable sort keeps feature indices ascending inside each group,
            # so the result matches np.where(group_ids == g)[0] for every g
            # without scanning the id vector once per group.  Negative ids
            # sort before 0 and are therefore not assigned to any group.
            order = np.argsort(group_ids, kind="stable")
            bounds = np.searchsorted(group_ids[order], np.arange(n_groups + 1))
            return [order[bounds[g]:bounds[g + 1]] for g in range(n_groups)]

        if isinstance(groups, np.ndarray) and groups.ndim == 1:
            return _from_group_ids(groups)
        if isinstance(groups, (list, tuple)):
            if len(groups) == 0:
                raise ValueError("groups must not be empty")
            if isinstance(groups[0], (list, tuple, np.ndarray)):
                return [np.asarray(g, dtype=int) for g in groups]
            return _from_group_ids(groups)
        raise TypeError(
            f"groups must be list or array, got {type(groups).__name__}"
        )

    def _init_groups(self, groups):
        """Parse group specification into internal format.

        Features in ``0..max_index`` that no group mentions are appended as
        single-feature groups (with a ``UserWarning``).  All derived layout
        attributes are computed after that step, so they always agree with
        the final group list.  Nothing on ``self`` is modified if validation
        fails.

        Raises
        ------
        ValueError
            If the groups are empty, contain no feature index, contain a
            negative index, or mention a feature more than once.
        TypeError
            If ``groups`` has an unsupported type.
        """
        import warnings

        group_indices = self._parse_groups(groups)

        if group_indices:
            flat = np.concatenate(
                [np.asarray(g, dtype=np.int64) for g in group_indices]
            )
        else:
            flat = np.empty(0, dtype=np.int64)
        if flat.size == 0:
            raise ValueError("groups must contain at least one feature index")
        if int(flat.min()) < 0:
            # Negative indices would silently wrap around when used for
            # gather/scatter on the coefficient vector.
            raise ValueError("groups contain negative feature indices")
        unique_idx = np.unique(flat)
        if unique_idx.size != flat.size:
            raise ValueError("groups contain duplicate feature indices")

        n_features = int(flat.max()) + 1
        if unique_idx.size != n_features:
            # Auto-fill missing indices as single-feature groups
            missing = np.setdiff1d(np.arange(n_features), unique_idx).tolist()
            warnings.warn(
                f"Groups do not cover features {missing}. "
                f"Auto-adding {len(missing)} single-feature groups.",
                UserWarning, stacklevel=2,
            )
            group_indices.extend(np.array([idx], dtype=int) for idx in missing)
            flat = np.concatenate([flat, np.asarray(missing, dtype=np.int64)])

        n_groups = len(group_indices)
        sizes = np.array([len(g) for g in group_indices], dtype=np.int64)

        self._group_indices = group_indices
        self._n_groups = n_groups
        self._group_sizes = sizes
        self._sqrt_pg = np.sqrt(sizes.astype(float))

        # Equal-size groups enable the reshape-based fast path.
        self._all_equal_size = bool(np.all(sizes == sizes[0]))
        self._group_size_uniform = (
            int(sizes[0]) if self._all_equal_size else None
        )

        # Groups are contiguous when, concatenated in order, they enumerate
        # 0, 1, 2, ... without gaps or permutation.
        self._is_contiguous = bool(
            np.array_equal(flat, np.arange(flat.size))
        )
        # Gather/scatter permutation, only needed for non-contiguous layouts.
        self._flat_indices = None if self._is_contiguous else flat

        # (row, col) coordinates of every feature inside the zero-padded
        # (G, max_size) matrix used by the unequal-size path.
        row_idx = np.repeat(np.arange(n_groups, dtype=np.int64), sizes)
        group_starts = np.cumsum(sizes) - sizes
        col_idx = (
            np.arange(flat.size, dtype=np.int64)
            - np.repeat(group_starts, sizes)
        )
        self._padded_row_idx = row_idx
        self._padded_col_idx = col_idx

        # feature index -> group index
        feat_to_group = np.empty(n_features, dtype=np.int64)
        feat_to_group[flat] = row_idx
        self._group_feat_idx = feat_to_group

        self._invalidate_device_caches()

    # ----------------------------------------------------------------
    # Device caches
    # ----------------------------------------------------------------

    def _get_cached(self, attr_name, xp, w):
        """Get or create cached device tensor for a numpy attribute.

        ``attr_name`` is the name of a host attribute (e.g.
        ``'_padded_row_idx'``).  For torch and cupy the converted array is
        stored in ``<attr_name>_<backend>``; a torch mirror is rebuilt when it
        lives on a different device than ``w``.  For any other backend the
        host data is converted with ``xp.asarray`` and not cached.
        """
        host_value = getattr(self, attr_name)
        backend = xp.__name__
        if backend not in ("torch", "cupy"):
            return xp.asarray(host_value)

        cache_attr = f"{attr_name}_{backend}"
        cached = getattr(self, cache_attr, None)
        if cached is not None and backend == "torch":
            ref_device = getattr(w, "device", None)
            if ref_device is not None and cached.device != ref_device:
                cached = None
        if cached is None:
            cached = _to_backend_array(host_value, xp, w)
            setattr(self, cache_attr, cached)
        return cached

    def _get_sqrt_pg(self, xp, w):
        """Cached device tensor for _sqrt_pg."""
        return self._get_cached('_sqrt_pg', xp, w)

    def _get_flat_indices(self, xp, w):
        """Cached device tensor for _flat_indices (``None`` if contiguous)."""
        if getattr(self, '_flat_indices', None) is None:
            return None
        return self._get_cached('_flat_indices', xp, w)

    # ----------------------------------------------------------------
    # Shared gather / scatter helpers
    # ----------------------------------------------------------------

    @staticmethod
    def _cast_like(values, like):
        """Cast a torch tensor to ``like.dtype``; other arrays pass through.

        Needed before fancy-index assignment on torch when ``values`` was
        promoted (e.g. by the float64 ``sqrt(p_g)`` factors) to a different
        dtype than the destination.
        """
        if hasattr(values, "to") and values.dtype != like.dtype:
            return values.to(like.dtype)
        return values

    def _equal_group_matrix(self, w_feat, xp, w_ref, G, gs):
        """Arrange the feature coefficients as a ``(G, gs)`` matrix."""
        if not self._is_contiguous:
            w_feat = w_feat[self._get_flat_indices(xp, w_ref)]
        return w_feat.reshape(G, gs)

    def _write_features(self, result, values, p_total, xp, w_ref):
        """Write per-feature ``values`` (in group order) into ``result``."""
        if self._is_contiguous:
            result[:p_total] = values
        else:
            flat_idx = self._get_flat_indices(xp, w_ref)
            result[flat_idx] = self._cast_like(values, result)
        return result

    def _pad_groups(self, w_feat, xp, w_ref, G, max_sz):
        """Scatter ``w_feat`` into a zero-padded ``(G, max_sz)`` matrix.

        Returns ``(padded, row_idx, col_idx, flat_idx)``; ``flat_idx`` is
        ``None`` for contiguous layouts.
        """
        padded = _backend_zeros((G, max_sz), xp, dtype=w_feat.dtype, ref_arr=w_ref)
        row_idx = self._get_cached('_padded_row_idx', xp, w_ref)
        col_idx = self._get_cached('_padded_col_idx', xp, w_ref)
        flat_idx = None if self._is_contiguous else self._get_flat_indices(xp, w_ref)
        padded[row_idx, col_idx] = w_feat if flat_idx is None else w_feat[flat_idx]
        return padded, row_idx, col_idx, flat_idx

    def _gather(self, w, xp):
        """Permute w so groups are contiguous. Only valid for equal-size groups."""
        if not self._all_equal_size:
            raise ValueError("_gather requires equal-size groups; use _proximal_padded instead")
        if self._is_contiguous:
            return w.reshape(self._n_groups, self._group_size_uniform)
        flat_idx = self._get_flat_indices(xp, w)
        return w[flat_idx].reshape(self._n_groups, self._group_size_uniform)

    def _scatter(self, w_mat_flat, result, xp):
        """Scatter vectorized result back. Plain copy if already contiguous."""
        if self._is_contiguous:
            result[:] = w_mat_flat
        else:
            flat_idx = self._get_flat_indices(xp, result)
            result[flat_idx] = self._cast_like(w_mat_flat, result)
        return result

    def _batched_group_norms_vec(self, coef_feat, xp, w_ref):
        """Vectorized batched group norms using padded fancy indexing.

        Builds the zero-padded ``(G, max_size)`` matrix and takes the norm
        along ``dim=1``; padding zeros do not change the norms.
        """
        G = self._n_groups
        max_sz = int(self._group_sizes.max())
        padded, _, _, _ = self._pad_groups(coef_feat, xp, w_ref, G, max_sz)
        return _vector_norm(padded, xp, dim=1)

    # ----------------------------------------------------------------
    # Value
    # ----------------------------------------------------------------

    def value(self, coef) -> float:
        """Return the penalty value for ``coef`` as a Python float.

        Raises
        ------
        ValueError
            If no groups have been set.
        """
        if self._group_indices is None:
            raise ValueError("groups must be set before calling value()")

        xp = _get_xp(coef)

        p_total = int(self._group_sizes.sum())
        coef_feat = coef[:p_total]  # handle augmented intercept

        # Compute all group norms in one batch (stays on device)
        if self._all_equal_size and self._group_size_uniform is not None:
            w_mat = self._equal_group_matrix(
                coef_feat, xp, coef, self._n_groups, self._group_size_uniform
            )
            norms = _vector_norm(w_mat, xp, dim=1)
        else:
            norms = self._batched_group_norms_vec(coef_feat, xp, coef)

        sqrt_pg = self._get_sqrt_pg(xp, coef)

        if xp.__name__ == "torch":
            return xp.sum(self.alpha * sqrt_pg * norms).item()
        if xp.__name__ == "cupy":
            return float(xp.sum(self.alpha * sqrt_pg * norms))
        return float(np.sum(self.alpha * sqrt_pg * norms))

    # ----------------------------------------------------------------
    # Gradient
    # ----------------------------------------------------------------

    def gradient(self, coef) -> np.ndarray:
        """Return ``alpha * sqrt(p_g) * w_g / ||w_g||`` for every group.

        Groups whose norm is at most ``1e-15`` get a zero gradient.  The
        result has the same length as ``coef``; entries beyond the grouped
        features are zero.

        Raises
        ------
        ValueError
            If no groups have been set.
        """
        if self._group_indices is None:
            raise ValueError("groups must be set before calling gradient()")

        xp = _get_xp(coef)
        is_torch = xp.__name__ == "torch"
        is_cupy = xp.__name__ == "cupy"

        p_total = int(self._group_sizes.sum())
        coef_feat = coef[:p_total]  # handle augmented intercept

        equal_size = self._all_equal_size and self._group_size_uniform is not None
        if equal_size:
            w_mat = self._equal_group_matrix(
                coef_feat, xp, coef, self._n_groups, self._group_size_uniform
            )
            norms = _vector_norm(w_mat, xp, dim=1)
        else:
            norms = self._batched_group_norms_vec(coef_feat, xp, coef)

        sqrt_pg = self._get_sqrt_pg(xp, coef)

        # Divide by a clamped norm so zero groups never produce inf/nan, then
        # mask those groups out explicitly.
        safe_norms = xp.clamp(norms, min=1e-15) if is_torch else xp.maximum(norms, 1e-15)
        scale = xp.where(norms > 1e-15,
                         self.alpha * sqrt_pg / safe_norms,
                         0.0)

        if is_torch or is_cupy:
            grad = xp.zeros_like(coef)
        else:
            # Float output even for integer input, on both layout paths.
            grad = np.zeros_like(coef, dtype=float)

        if equal_size:
            grad_flat = (w_mat * scale[:, None]).reshape(-1)
            return self._write_features(grad, grad_flat, p_total, xp, coef)

        # Unequal groups: look up each feature's group scale and multiply.
        feat_idx = self._get_cached('_group_feat_idx', xp, coef)
        grad[:p_total] = scale[feat_idx] * coef_feat
        return grad

    # ----------------------------------------------------------------
    # Proximal operator (block soft-thresholding)
    # ----------------------------------------------------------------

    def proximal(self, w, step: float, backend: str = "numpy"):
        """Group soft-thresholding: each group is shrunk toward zero.

        ``backend`` values ``"cupy"`` and ``"torch"`` select the vectorized
        implementation; anything else uses the per-group NumPy loop.

        Raises
        ------
        ValueError
            If no groups have been set.
        """
        if self._group_indices is None:
            raise ValueError("groups must be set before calling proximal()")

        if backend == "cupy":
            import cupy as cp
            return self._proximal_vectorized(w, step, cp)
        elif backend == "torch":
            import torch
            return self._proximal_vectorized(w, step, torch)
        else:
            return self._proximal_loop(w, step, np)

    def _proximal_loop(self, w, step, xp):
        """Per-group serial loop (numpy CPU path)."""
        result = w.copy() if hasattr(w, 'copy') else w.clone()
        for g, idx in enumerate(self._group_indices):
            w_g = w[idx]
            norm = float(xp.linalg.norm(w_g))
            thresh = self.alpha * self._sqrt_pg[g] * step
            if norm > thresh:
                result[idx] = w_g * (1.0 - thresh / norm)
            else:
                result[idx] = 0.0
        return result

    def _proximal_vectorized(self, w, step, xp):
        """Vectorized proximal: dispatch to the equal-size or padded path."""
        G = self._n_groups

        if self._all_equal_size and self._group_size_uniform is not None:
            gs = self._group_size_uniform
            return self._proximal_equal(w, step, xp, G, gs)

        # Unequal groups: pad to max size
        max_sz = int(self._group_sizes.max())
        return self._proximal_padded(w, step, xp, G, max_sz)

    def _proximal_equal(self, w, step, xp, G, gs):
        """Fast path: all groups equal size, vectorized norm + scale."""
        p_total = G * gs
        w_feat = w[:p_total]  # handle augmented intercept
        w_mat = self._equal_group_matrix(w_feat, xp, w, G, gs)

        sqrt_pg_arr = self._get_sqrt_pg(xp, w)

        # Torch compiled fast path
        if xp.__name__ == "torch":
            compiled_fn = _get_group_lasso_torch_compiled_equal()
            if compiled_fn is not None:
                scaled_flat = compiled_fn(w_mat, sqrt_pg_arr, self.alpha, step)
                return self._write_features(w.clone(), scaled_flat, p_total, xp, w)

        # Generic vectorized path
        norms = _vector_norm(w_mat, xp, dim=1)
        thresh = self.alpha * sqrt_pg_arr * step
        scale = xp.clip(1.0 - thresh / (norms + 1e-12), 0.0, None)
        scaled_flat = (w_mat * scale[:, None]).reshape(-1)

        result = w.copy() if hasattr(w, 'copy') else w.clone()
        return self._write_features(result, scaled_flat, p_total, xp, w)

    def _proximal_padded(self, w, step, xp, G, max_sz):
        """General path: pad unequal groups, compute norms vectorized."""
        p_total = int(self._group_sizes.sum())
        w_feat = w[:p_total]  # handle augmented intercept

        padded, row_idx, col_idx, flat_idx = self._pad_groups(
            w_feat, xp, w, G, max_sz
        )

        norms = _vector_norm(padded, xp, dim=1)
        sqrt_pg_arr = self._get_sqrt_pg(xp, w)
        thresh = self.alpha * sqrt_pg_arr * step
        scale = xp.clip(1.0 - thresh / (norms + 1e-12), 0.0, None)

        # Shrink every row, then read the real (non-padding) entries back in
        # the same order they were written.
        shrunk = (padded * scale[:, None])[row_idx, col_idx]

        result = w.copy() if hasattr(w, 'copy') else w.clone()
        if flat_idx is None:
            result[:p_total] = shrunk
        else:
            result[flat_idx] = self._cast_like(shrunk, result)
        return result

    # ----------------------------------------------------------------

    def get_params(self) -> dict:
        """Return the base-class parameters plus ``alpha`` and ``n_groups``."""
        params = super().get_params()
        params.update({
            "alpha": self.alpha,
            "n_groups": self._n_groups if self._group_indices else 0,
        })
        return params
555
+
556
+
557
class AdaptiveGroupLassoPenalty(GroupLassoPenalty):
    """Group Lasso with per-group weights for LLA linearization of group SCAD/MCP.

    The penalty is:
    P(w) = alpha * sum_g weights_g * sqrt(p_g) * ||w_g||_2

    where weights_g are per-group LLA weights.

    Parameters
    ----------
    groups : list of lists, or 1D array-like
        Group specification, forwarded to ``GroupLassoPenalty``.
    alpha : float, default=1.0
        Regularization strength.
    weights : array-like of shape (n_groups,), optional
        Per-group weights.  ``None`` means uniform weights, i.e. the same
        thresholds as the plain group lasso.

    Notes
    -----
    NOTE(review): only the proximal operators are overridden here.
    ``value()`` and ``gradient()`` are inherited unchanged and do not read the
    weights -- confirm whether callers expect the weighted objective.
    """

    name = "adaptive_group_lasso"

    def __init__(self, groups, alpha=1.0, weights=None):
        super().__init__(alpha=alpha, groups=groups)
        # weights: per-group weight array, shape (n_groups,)
        # None = uniform (same as GroupLasso)
        self._group_weights = weights
        self._group_weights_torch = None
        self._group_weights_cupy = None

    def set_weights(self, weights):
        """Update per-group weights (numpy array, shape (n_groups,))."""
        self._group_weights = weights
        # Invalidate cached device tensors
        self._group_weights_torch = None
        self._group_weights_cupy = None

    def _get_group_weights(self, xp, w):
        """Return the weights as an ``xp`` array, or ``None`` if unset.

        torch and cupy conversions are cached in separate attributes; a torch
        cache living on a different device than ``w`` is rebuilt.  Other
        backends get an uncached ``xp.asarray`` conversion so they never share
        a cache slot with cupy.
        """
        if self._group_weights is None:
            return None

        backend = xp.__name__
        if backend not in ("torch", "cupy"):
            return xp.asarray(self._group_weights)

        cache_attr = f"_group_weights_{backend}"
        cached = getattr(self, cache_attr, None)
        if cached is not None and backend == "torch":
            ref_device = getattr(w, "device", None)
            if ref_device is not None and cached.device != ref_device:
                cached = None
        if cached is None:
            cached = _to_backend_array(self._group_weights, xp, w)
            setattr(self, cache_attr, cached)
        return cached

    def _group_thresholds(self, step, xp, w):
        """Per-group soft-threshold levels ``alpha * weights_g * sqrt(p_g) * step``."""
        sqrt_pg_arr = self._get_sqrt_pg(xp, w)
        weights_arr = self._get_group_weights(xp, w)
        if weights_arr is None:
            # Uniform weights: multiplying by an array of ones is a no-op, so
            # skip it.  This also avoids calling ``.to(device=...)`` on cupy
            # arrays, which expose ``.device`` but have no ``.to`` method.
            return self.alpha * sqrt_pg_arr * step
        return self.alpha * weights_arr * sqrt_pg_arr * step

    @staticmethod
    def _shrink_factors(thresh, norms, xp):
        """Return ``max(0, 1 - thresh / (norms + 1e-12))`` per group."""
        ratio = 1.0 - thresh / (norms + 1e-12)
        if xp.__name__ == "torch":
            return xp.clamp(ratio, min=0.0)
        return xp.clip(ratio, 0.0, None)

    @staticmethod
    def _assign_features(result, values, p_total, flat_idx):
        """Write ``values`` (group order) into ``result`` and return it.

        ``flat_idx`` is ``None`` for contiguous layouts.  For fancy-index
        assignment on torch the values are first cast to ``result.dtype``,
        because the float64 ``sqrt(p_g)`` factors may have promoted them.
        """
        if flat_idx is None:
            result[:p_total] = values
            return result
        if hasattr(values, "to") and values.dtype != result.dtype:
            values = values.to(result.dtype)
        result[flat_idx] = values
        return result

    def _proximal_loop(self, w, step, xp):
        """Per-group serial loop with per-group weights."""
        result = w.copy() if hasattr(w, 'copy') else w.clone()
        for g, idx in enumerate(self._group_indices):
            w_g = w[idx]
            norm = float(xp.linalg.norm(w_g))
            wg = float(self._group_weights[g]) if self._group_weights is not None else 1.0
            thresh = self.alpha * wg * self._sqrt_pg[g] * step
            if norm > thresh:
                result[idx] = w_g * (1.0 - thresh / norm)
            else:
                result[idx] = 0.0
        return result

    def _proximal_equal(self, w, step, xp, G, gs):
        """Fast path: all groups equal size, vectorized norm + scale with weights."""
        p_total = G * gs
        w_feat = w[:p_total]  # handle augmented intercept

        flat_idx = None if self._is_contiguous else self._get_flat_indices(xp, w)
        gathered = w_feat if flat_idx is None else w_feat[flat_idx]
        w_mat = gathered.reshape(G, gs)

        norms = _vector_norm(w_mat, xp, dim=1)
        scale = self._shrink_factors(self._group_thresholds(step, xp, w), norms, xp)
        scaled_flat = (w_mat * scale[:, None]).reshape(-1)

        result = w.clone() if hasattr(w, 'clone') else w.copy()
        return self._assign_features(result, scaled_flat, p_total, flat_idx)

    def _proximal_padded(self, w, step, xp, G, max_sz):
        """General path: pad unequal groups with per-group weights (fancy indexing)."""
        p_total = int(self._group_sizes.sum())
        w_feat = w[:p_total]  # handle augmented intercept

        # Scatter the features into a zero-padded (G, max_sz) matrix.
        padded = _backend_zeros((G, max_sz), xp, dtype=w.dtype, ref_arr=w)
        row_idx_dev = self._get_cached('_padded_row_idx', xp, w)
        col_idx_dev = self._get_cached('_padded_col_idx', xp, w)
        flat_idx = None if self._is_contiguous else self._get_flat_indices(xp, w)
        padded[row_idx_dev, col_idx_dev] = (
            w_feat if flat_idx is None else w_feat[flat_idx]
        )

        norms = _vector_norm(padded, xp, dim=1)
        scale = self._shrink_factors(self._group_thresholds(step, xp, w), norms, xp)
        shrunk = (padded * scale[:, None])[row_idx_dev, col_idx_dev]

        result = w.copy() if hasattr(w, 'copy') else w.clone()
        return self._assign_features(result, shrunk, p_total, flat_idx)

    def get_params(self) -> dict:
        """Return the parent parameters plus the current ``weights``."""
        params = super().get_params()
        params.update({
            "weights": self._group_weights,
        })
        return params