statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,313 @@
1
+ """
2
+ Adaptive L1 penalty (Adaptive Lasso).
3
+
4
+ Zou, JASA 2006. Convex penalty with data-driven per-coordinate weights.
5
+
6
+ The penalty is:
7
+ P(w) = alpha * sum(weights_j * |w_j|)
8
+ where weights_j = 1 / (|init_coef_j| + eps)^nu.
9
+
10
+ The weights are set via set_weights() using an initial OLS or Ridge estimate.
11
+ """
12
+
13
+ __all__ = ["AdaptiveL1Penalty"]
14
+
15
+ from typing import Optional
16
+ import numpy as np
17
+ from statgpu.penalties._base import Penalty
18
+
19
+ # ---- torch.compile lazy-loader (fuses elementwise ops into 1 kernel) ---------
20
+ _ADAPTIVE_L1_PROXIMAL_TORCH_COMPILED = None
21
+
22
+
23
def _get_adaptive_l1_torch_compiled():
    """Return the lazily built ``torch.compile`` soft-threshold kernel.

    The compiled callable takes ``(w, thresh_tensor)`` and evaluates
    ``sign(w) * relu(|w| - thresh_tensor)``. A successfully built kernel is
    stored in the module-level ``_ADAPTIVE_L1_PROXIMAL_TORCH_COMPILED`` and
    reused on later calls.

    Returns
    -------
    callable or None
        The compiled function, or ``None`` when ``_torch_compile_ok()``
        reports that compilation is unavailable or when building the
        wrapper raised.
    """
    global _ADAPTIVE_L1_PROXIMAL_TORCH_COMPILED

    compiled = _ADAPTIVE_L1_PROXIMAL_TORCH_COMPILED
    if compiled is None:
        # Imported lazily, at call time rather than at module import.
        from statgpu.penalties import _torch_compile_ok

        # NOTE(review): a negative outcome is stored as None, which is also
        # the "not built yet" marker, so the availability check is repeated
        # on every call while compilation is unavailable.
        if _torch_compile_ok():
            try:
                import torch

                def _soft_threshold(w, thresh_tensor):
                    magnitude = torch.abs(w) - thresh_tensor
                    return torch.sign(w) * torch.relu(magnitude)

                # NOTE(review): 'reduce-overhead' presumably enables CUDA
                # graphs; confirm that callers never hold an output across
                # a later invocation of the compiled function.
                compiled = torch.compile(
                    _soft_threshold, dynamic=True, mode='reduce-overhead'
                )
            except Exception:
                # Best effort: callers fall back to the eager expression.
                compiled = None
        _ADAPTIVE_L1_PROXIMAL_TORCH_COMPILED = compiled
    return compiled
39
+
40
+
41
class AdaptiveL1Penalty(Penalty):
    """Adaptive L1 penalty (Adaptive Lasso).

    Parameters
    ----------
    alpha : float, default=1.0
        Regularization strength.
    nu : float, default=1.0
        Exponent for weight computation (1 or 2, per Zou 2006).
    eps : float, default=1e-8
        Small constant to avoid division by zero.
    init_method : str, default='auto'
        Method for initial coefficient estimates:
        - 'auto': OLS if n > p, Ridge otherwise
        - 'ols': forced OLS (errors if p > n)
        - 'ridge': forced Ridge (always works)
        The value is only stored here; the initial fit itself is performed
        by the caller.
    normalize : bool, default=True
        If True, normalize weights by their mean to match R glmnet's
        penalty.factor convention (R normalizes so mean(pf) = 1).
        Set to False to use raw 1/|coef| weights with no normalization.
    weights : array-like, optional
        Pre-computed per-coordinate weights. When provided, ``set_weights``
        is a no-op and these weights are used (after division by their mean
        when ``normalize=True``). When ``None``, weights are computed from
        an initial fit via ``set_weights``.

    Notes
    -----
    With fixed weights adaptive_l1 is convex (``is_convex=True``). However,
    when used as the inner solver for non-convex penalties (SCAD, MCP) via
    LLA, the overall optimization is non-convex and may converge to different
    local minima depending on the solver and initialization. In standalone use
    (fixed weights from a pre-fit), results are deterministic and reproducible.

    Until weights are available (neither passed in nor computed by
    ``set_weights``), ``value``, ``gradient``, ``proximal`` and
    ``lla_weights`` behave as a plain L1 penalty with unit weights.
    """

    name = "adaptive_l1"
    is_convex = True
    requires_init = True

    # Lazy-loaded fused CuPy kernel, shared by all instances.
    _ADAPTIVE_L1_PROXIMAL_CUPY = None

    # Per-instance attributes that hold the cached ``alpha * weights`` device
    # arrays together with the inputs they were derived from.
    _DEVICE_CACHE_ATTRS = (
        '_alpha_w_torch', '_alpha_w_torch_src', '_alpha_w_torch_alpha',
        '_alpha_w_cupy', '_alpha_w_cupy_src', '_alpha_w_cupy_alpha',
    )

    def __init__(
        self,
        alpha: float = 1.0,
        nu: float = 1.0,
        eps: float = 1e-8,
        init_method: str = "auto",
        normalize: bool = True,
        weights: Optional[np.ndarray] = None,
    ):
        self.alpha = alpha
        self.nu = nu
        self.eps = eps
        self.init_method = init_method
        self.normalize = normalize
        # Always defined so the attribute can be read before set_weights().
        self._norm_factor = 1.0
        if weights is not None:
            w = np.asarray(weights, dtype=float)
            if self.normalize:
                # Normalize by mean to match R glmnet's penalty.factor convention.
                mean_w = float(np.mean(w))
                if mean_w > 0:
                    w = w / mean_w
                    self._norm_factor = mean_w
            self._weights = w
        else:
            self._weights = None

    # ----------------------------------------------------------------
    # Weight handling helpers
    # ----------------------------------------------------------------

    def _invalidate_device_cache(self):
        """Drop every cached ``alpha * weights`` device array."""
        for attr in self._DEVICE_CACHE_ATTRS:
            if hasattr(self, attr):
                delattr(self, attr)

    def _resolve_weights(self, ref):
        """Return the stored weights, or unit weights shaped like ``ref``.

        The unit-weight fallback is deliberately not stored on the instance,
        so a later ``set_weights`` call can still install real weights.
        """
        weights = getattr(self, "_weights", None)
        if weights is None:
            shape = tuple(int(dim) for dim in np.shape(ref))
            weights = np.ones(shape, dtype=np.float64)
        return weights

    def _weights_like(self, ref):
        """Return the weights as an array of the same backend as ``ref``."""
        weights = self._resolve_weights(ref)
        ref_mod = type(ref).__module__
        if ref_mod.startswith("torch"):
            import torch
            # Never cast weights to an integer dtype (would truncate them).
            dtype = ref.dtype if ref.is_floating_point() else torch.float64
            return torch.as_tensor(weights, dtype=dtype, device=ref.device)
        if ref_mod.startswith("cupy"):
            import cupy as cp
            dtype = ref.dtype if ref.dtype.kind == "f" else np.float64
            return cp.asarray(weights, dtype=dtype)
        return np.asarray(weights)

    def _alpha_weights_on_device(self, ref, backend):
        """Return ``alpha * weights`` as a float64 array for ``backend``.

        The result is cached on the instance and rebuilt whenever the
        weights object, ``alpha`` or (for torch) the target device changes.

        Parameters
        ----------
        ref : torch.Tensor or cupy.ndarray
            Array whose device the result must live on.
        backend : {'torch', 'cupy'}
            Target backend.
        """
        weights = self._resolve_weights(ref)
        cache_key = '_alpha_w_' + backend
        src_key = cache_key + '_src'
        alpha_key = cache_key + '_alpha'

        cached = getattr(self, cache_key, None)
        fresh = (
            cached is not None
            and getattr(self, src_key, None) is weights
            and getattr(self, alpha_key, None) == self.alpha
        )
        if fresh and backend == "torch":
            fresh = cached.device == ref.device
        if fresh:
            return cached

        w_mod = type(weights).__module__
        if backend == "torch":
            import torch
            if w_mod.startswith("torch"):
                scaled = self.alpha * weights.to(
                    device=ref.device, dtype=torch.float64
                )
            else:
                scaled = torch.tensor(
                    self.alpha * np.asarray(weights, dtype=float),
                    device=ref.device, dtype=torch.float64,
                )
        else:
            import cupy as cp
            if w_mod.startswith("cupy"):
                scaled = self.alpha * weights
            else:
                scaled = cp.asarray(self.alpha * np.asarray(weights, dtype=float))

        # Only cache results derived from the stored weights; the unit-weight
        # fallback is a fresh object on every call and would never hit.
        if weights is getattr(self, "_weights", None):
            setattr(self, cache_key, scaled)
            setattr(self, src_key, weights)
            setattr(self, alpha_key, self.alpha)
        return scaled

    def set_weights(self, coef: np.ndarray):
        """Compute adaptive weights from initial coefficient estimates.

        weights_j = 1 / (|coef_j| + eps)^nu

        If weights are already present (passed to ``__init__``, assigned to
        ``_weights`` externally, or computed by an earlier call), this
        method is a no-op.

        When ``normalize=True`` (default), weights are divided by their
        mean to match R glmnet's penalty.factor convention (R normalizes
        penalty factors so that mean(pf) = 1 internally).

        When ``normalize=False``, raw 1/|coef| weights are used (no
        normalization).

        Parameters
        ----------
        coef : array-like
            Initial coefficient estimates; converted to a flat float64
            NumPy array before use.
        """
        # NOTE(review): weights computed by a previous call also make this a
        # no-op, so reusing one penalty object across fits keeps the first
        # weights. Confirm with callers whether that is intended.
        if getattr(self, "_weights", None) is not None:
            return
        # Convert to numpy for weight computation (weights are always stored as numpy)
        from statgpu.backends._utils import _to_numpy
        coef_np = np.asarray(_to_numpy(coef), dtype=np.float64).ravel()
        # Invalidate cached device tensors so proximal recomputes them.
        self._invalidate_device_cache()
        self._norm_factor = 1.0
        # If the init coef is all-zero (e.g., ridge init diverged),
        # fall back to uniform weights so adaptive_l1 reduces to L1.
        if not np.any(np.abs(coef_np) > 1e-12):
            self._weights = np.ones_like(coef_np)
            return
        raw = 1.0 / (np.abs(coef_np) + self.eps) ** self.nu
        if self.normalize:
            mean_w = float(np.mean(raw))
            if mean_w > 0:
                raw = raw / mean_w
                self._norm_factor = mean_w
        self._weights = raw

    # ----------------------------------------------------------------
    # Value
    # ----------------------------------------------------------------

    def value(self, coef) -> float:
        """Return ``alpha * sum(weights_j * |coef_j|)``.

        The backend is chosen from the module of ``type(coef)``: torch and
        CuPy inputs are evaluated on their device, anything else with NumPy.
        """
        mod = type(coef).__module__
        if mod.startswith("torch"):
            import torch
            scaled = self._alpha_weights_on_device(coef, "torch")
            return (scaled * torch.abs(coef)).sum().item()
        if mod.startswith("cupy"):
            import cupy as cp
            scaled = self._alpha_weights_on_device(coef, "cupy")
            return float((scaled * cp.abs(coef)).sum())
        weights = self._resolve_weights(coef)
        return self.alpha * np.sum(weights * np.abs(coef))

    # ----------------------------------------------------------------
    # Gradient
    # ----------------------------------------------------------------

    def gradient(self, coef):
        """Return ``alpha * weights * sign(coef)`` in the backend of ``coef``."""
        weights = self._weights_like(coef)
        mod = type(coef).__module__
        if mod.startswith("torch"):
            import torch
            sign = torch.sign(coef)
        elif mod.startswith("cupy"):
            import cupy as cp
            sign = cp.sign(coef)
        else:
            sign = np.sign(coef)
        return self.alpha * weights * sign

    # ----------------------------------------------------------------
    # Proximal operator (FISTA path)
    # ----------------------------------------------------------------

    def proximal(
        self,
        w,
        step: float,
        backend: str = "numpy",
    ):
        """Per-coordinate soft-threshold with per-coordinate thresholds.

        Parameters
        ----------
        w : array
            Input array (pre-proximal update) of the given backend.
        step : float
            Step size; the threshold is ``alpha * weights * step``.
        backend : str, default='numpy'
            'cupy' and 'torch' select the device implementations; any other
            value uses NumPy.

        Returns
        -------
        array
            ``sign(w) * max(|w| - alpha * weights * step, 0)``.
        """
        if backend == "cupy":
            import cupy as cp
            if AdaptiveL1Penalty._ADAPTIVE_L1_PROXIMAL_CUPY is None:
                AdaptiveL1Penalty._ADAPTIVE_L1_PROXIMAL_CUPY = cp.ElementwiseKernel(
                    'float64 w, float64 thresh',
                    'float64 result',
                    '''
                    double abs_w = abs(w);
                    double sign_w = (w > 0.0) ? 1.0 : ((w < 0.0) ? -1.0 : 0.0);
                    if (abs_w > thresh) {
                        result = sign_w * (abs_w - thresh);
                    } else {
                        result = 0.0;
                    }
                    ''',
                    'adaptive_l1_proximal',
                )
            thresh_gpu = self._alpha_weights_on_device(w, "cupy") * step
            return AdaptiveL1Penalty._ADAPTIVE_L1_PROXIMAL_CUPY(w, thresh_gpu)

        if backend == "torch":
            import torch
            thresh_t = self._alpha_weights_on_device(w, "torch") * step
            compiled_fn = _get_adaptive_l1_torch_compiled()
            if compiled_fn is not None:
                return compiled_fn(w, thresh_t)
            return torch.sign(w) * torch.relu(torch.abs(w) - thresh_t)

        weights = self._resolve_weights(w)
        thresh_arr = self.alpha * np.asarray(weights, dtype=float) * step
        return np.sign(w) * np.maximum(np.abs(w) - thresh_arr, 0.0)

    # ----------------------------------------------------------------
    # LLA weights (identity: this is already a weighted L1 penalty)
    # ----------------------------------------------------------------

    def lla_weights(self, coef):
        """Return LLA weights, converted to the same backend as coef."""
        weights = self._weights_like(coef)
        if isinstance(weights, np.ndarray):
            # Hand out a copy so callers cannot mutate the stored weights.
            return weights.copy()
        return weights

    # ----------------------------------------------------------------

    def get_params(self) -> dict:
        """Return the base parameters extended with ``alpha`` and ``nu``."""
        params = super().get_params()
        params.update({
            "alpha": self.alpha,
            "nu": self.nu,
        })
        return params
@@ -0,0 +1,261 @@
1
+ """
2
+ Base class for all penalty functions in statgpu.
3
+
4
+ The penalty framework supports:
5
+ - Convex penalties (L1, L2, Elastic Net)
6
+ - Non-convex penalties (SCAD, MCP) via LLA approximation
7
+ - Group penalties (Group Lasso, Sparse Group Lasso)
8
+ - Adaptive/weighted penalties
9
+ """
10
+
11
+ __all__ = ["Penalty", "CompositePenalty"]
12
+
13
+
14
+ from abc import ABC, abstractmethod
15
+ from typing import Optional, Union, Any
16
+ import numpy as np
17
+
18
+ from statgpu.backends._array_ops import _xp
19
+
20
+
21
class Penalty(ABC):
    """
    Abstract base class for regularization penalties.

    A penalty function P(w) defines the regularization term in penalized
    regression:

        minimize: (1/(2n)) * ||y - Xw||²₂ + P(w)

    Subclasses must implement:
    - value(coef): Compute P(w)
    - gradient(coef): Compute ∇P(w)

    Subclasses used with proximal solvers must also override:
    - proximal(w, step, backend): Compute proximal operator
      (the base implementation raises NotImplementedError)

    For non-convex penalties (SCAD, MCP), also implement:
    - lla_weights(coef): LLA weights for local linear approximation
    """

    name: str = "base"
    is_convex: bool = True
    supports_group: bool = False
    requires_init: bool = False

    @abstractmethod
    def value(self, coef: np.ndarray) -> float:
        """
        Compute penalty value P(w).

        Parameters
        ----------
        coef : np.ndarray
            Coefficient vector.

        Returns
        -------
        float
            Penalty value.
        """

    @abstractmethod
    def gradient(self, coef: np.ndarray) -> np.ndarray:
        """
        Compute penalty gradient ∇P(w).

        Parameters
        ----------
        coef : np.ndarray
            Coefficient vector.

        Returns
        -------
        np.ndarray
            Gradient of penalty at coef.
        """

    def proximal(
        self,
        w: np.ndarray,
        step: float,
        backend: str = "numpy"
    ) -> np.ndarray:
        """
        Proximal operator: argmin_z { (1/2)||z - w||² + step * P(z) }

        The base class provides no implementation; every penalty that is
        used with a proximal solver must override this method.

        Parameters
        ----------
        w : np.ndarray
            Input array (pre-proximal update).
        step : float
            Step size (typically 1/Lipschitz constant).
        backend : str, default='numpy'
            Backend: 'numpy', 'cupy', or 'torch'.

        Returns
        -------
        np.ndarray
            Result of proximal operator.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by the subclass.
        """
        raise NotImplementedError(
            f"proximal() not implemented for {self.name}. "
            "Subclass must implement this method."
        )

    def lla_weights(self, coef: np.ndarray) -> np.ndarray:
        """
        Local Linear Approximation (LLA) weights for non-convex penalties.

        For a penalty P applied to |w_j|, the LLA around ``coef`` is:
            P(|w_j|) ≈ P(|coef_j|) + P'(|coef_j|) * (|w_j| - |coef_j|)

        so the per-coordinate weight is P'(|coef_j|).

        This is used to solve non-convex penalties via iteratively
        reweighted L1.

        Parameters
        ----------
        coef : np.ndarray
            Current coefficient estimate.

        Returns
        -------
        array
            LLA weights (default: ones for convex L1), created with the
            array module that ``_xp`` selects for ``coef``.
        """
        xp = _xp(coef)
        return xp.ones_like(coef)

    def get_params(self) -> dict:
        """
        Get penalty parameters for serialization.

        Returns
        -------
        dict
            Dictionary of penalty parameters.
        """
        return {"name": self.name}

    def _check_coef_shape(self, coef: np.ndarray) -> None:
        """Validate coefficient array shape.

        Raises
        ------
        ValueError
            If ``coef`` is not one-dimensional.
        """
        # np.ndim / np.shape read the object's own attribute when present
        # and otherwise convert it, so plain sequences are accepted too.
        if np.ndim(coef) != 1:
            raise ValueError(f"coef must be 1D, got shape {np.shape(coef)}")

    def __repr__(self) -> str:
        """Return ``ClassName(key=value, ...)`` built from ``get_params()``."""
        params = self.get_params()
        param_str = ", ".join(f"{k}={v}" for k, v in params.items())
        return f"{self.__class__.__name__}({param_str})"
154
+
155
+
156
class CompositePenalty(Penalty):
    """
    Composite penalty combining multiple penalties.

    P(w) = Σ weight_i * P_i(w)

    This allows combining different penalty types, e.g.:
    - Group Lasso + L1 (Sparse Group Lasso)
    - Group Lasso + SCAD
    """

    name = "composite"
    is_convex = True  # Only if all component penalties are convex

    def __init__(
        self,
        penalties: list,
        weights: Optional[list] = None,
    ):
        """
        Parameters
        ----------
        penalties : list of Penalty
            List of penalty objects.
        weights : list of float, optional
            Weight for each penalty. Default: equal weights.

        Raises
        ------
        ValueError
            If ``weights`` is given and its length differs from the number
            of penalties.
        """
        # Copy so later mutation of the caller's lists cannot desynchronize
        # penalties, weights and n_penalties.
        self.penalties = list(penalties)
        self.n_penalties = len(self.penalties)

        if weights is None:
            # An empty composite gets an empty weight list instead of
            # dividing by zero.
            if self.n_penalties:
                self.weights = [1.0 / self.n_penalties] * self.n_penalties
            else:
                self.weights = []
        else:
            if len(weights) != self.n_penalties:
                raise ValueError(
                    f"weights must have length {self.n_penalties}, "
                    f"got {len(weights)}"
                )
            self.weights = list(weights)

        # Composite is convex only if all components are convex
        self.is_convex = all(p.is_convex for p in self.penalties)

        # Composite requires init if any component requires it
        self.requires_init = any(p.requires_init for p in self.penalties)

    def value(self, coef: np.ndarray) -> float:
        """Sum of weighted penalty values."""
        total = 0.0
        for w, pen in zip(self.weights, self.penalties):
            total += w * pen.value(coef)
        return total

    def gradient(self, coef):
        """Sum of weighted penalty gradients."""
        xp = _xp(coef)
        total = xp.zeros_like(coef)
        for w, pen in zip(self.weights, self.penalties):
            total = total + w * pen.gradient(coef)
        return total

    def proximal(
        self,
        w: np.ndarray,
        step: float,
        backend: str = "numpy"
    ) -> np.ndarray:
        """
        Proximal for composite penalty.

        Note: This is an approximation. The exact proximal for a sum
        of penalties is not the composition of individual proximals
        (unless they commute). For most practical cases (e.g., sparse
        group lasso), this approximation works well.

        Each component's proximal is applied in list order with step
        ``step * weight_i``; the input ``w`` itself is not modified here.
        """
        # Sequential application of proximal operators
        # (Dykstra-like splitting, simplified)
        result = w.clone() if hasattr(w, 'clone') else w.copy()
        for weight, pen in zip(self.weights, self.penalties):
            result = pen.proximal(result, step * weight, backend)
        return result

    def lla_weights(self, coef):
        """LLA weights: weighted sum of individual LLA weights.

        For composite penalty P(w) = sum_i w_i * P_i(w),
        the LLA weight is sum_i w_i * P_i'(|coef|).

        Only non-convex components contribute. When every component is
        convex, unit weights are returned.
        """
        xp = _xp(coef)
        if all(p.is_convex for p in self.penalties):
            return xp.ones_like(coef)

        result = xp.zeros_like(coef)
        for weight, pen in zip(self.weights, self.penalties):
            if not pen.is_convex:
                result = result + weight * pen.lla_weights(coef)
        return result

    def get_params(self) -> dict:
        """Return the component names and weights of this composite."""
        params = {
            "name": "composite",
            "n_penalties": self.n_penalties,
            "penalties": [p.name for p in self.penalties],
            "weights": self.weights,
        }
        return params
@@ -0,0 +1,39 @@
1
+ """Shared penalty category constants.
2
+
3
+ Single source of truth for penalty name sets used across solver and model layers.
4
+ Adding a new penalty type only requires updating this file.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
# Differentiable penalties: handled without a proximal operator.
SMOOTH_PENALTIES = frozenset(("none", "null", "l2"))

# Convex penalties that are not differentiable everywhere (proximal needed).
NONSMOOTH_CONVEX = frozenset((
    "l1",
    "elasticnet",
    "en",
    "adaptive_l1",
    "adaptive_lasso",
    "group_lasso",
    "gl",
))

# Non-convex penalties (LLA or a specialized solver needed).
NONCONVEX = frozenset((
    "scad",
    "mcp",
    "group_mcp",
    "gmcp",
    "group_scad",
    "gscad",
))

# Every non-smooth penalty, convex or not.
NONSMOOTH = NONSMOOTH_CONVEX.union(NONCONVEX)

# L1-type penalties that produce sparse solutions.
SPARSE = frozenset((
    "l1",
    "elasticnet",
    "en",
    "adaptive_l1",
    "adaptive_lasso",
    "scad",
    "mcp",
))

# Penalties that act on groups of coefficients.
GROUP = frozenset((
    "group_lasso",
    "gl",
    "group_mcp",
    "gmcp",
    "group_scad",
    "gscad",
))

# Penalties for which the BB step is disabled (standard FISTA is used).
# Deliberately the same object as GROUP: BB step doesn't work well with
# group structure.
BB_DISABLED = GROUP