statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,107 @@
"""
L1 penalty (Lasso) implementation.

P(w) = α * ||w||₁

Provides ``L1Penalty``. Its proximal operator is elementwise soft
thresholding, with an optional torch.compile fast path for the torch backend.
"""

# Public API of this module.
__all__ = ["L1Penalty"]
8
+
9
+
10
+ from typing import Optional
11
+ from statgpu.backends._array_ops import _xp
12
+ import numpy as np
13
+ from statgpu.penalties._base import Penalty
14
+
15
# ---- torch.compile lazy-loader (fuses elementwise ops into 1 kernel) ---------
_L1_PROXIMAL_TORCH_COMPILED = None
# Set to True once building the compiled kernel has raised. This stops a
# failing torch.compile from being re-attempted on every proximal call of a
# solver loop.
_L1_PROXIMAL_TORCH_COMPILE_FAILED = False


def _get_l1_torch_compiled():
    """Return a cached ``torch.compile``'d soft-thresholding kernel, or ``None``.

    The kernel has signature ``(w, thresh)`` and computes
    ``sign(w) * relu(|w| - thresh)``. ``thresh`` may be a Python float or a
    tensor that broadcasts against ``w``.

    ``None`` is returned in two cases:

    * ``statgpu.penalties._torch_compile_ok()`` reports that compilation
      should not be used. This is re-checked on every call.
    * Importing torch or calling ``torch.compile`` raised. This failure is
      remembered for the rest of the process.

    Callers are expected to fall back to an eager implementation on ``None``.
    """
    global _L1_PROXIMAL_TORCH_COMPILED, _L1_PROXIMAL_TORCH_COMPILE_FAILED
    if _L1_PROXIMAL_TORCH_COMPILED is not None:
        return _L1_PROXIMAL_TORCH_COMPILED
    if _L1_PROXIMAL_TORCH_COMPILE_FAILED:
        return None
    from statgpu.penalties import _torch_compile_ok
    if not _torch_compile_ok():
        return None
    try:
        import torch

        def _prox(w, thresh):
            return torch.sign(w) * torch.relu(torch.abs(w) - thresh)

        _L1_PROXIMAL_TORCH_COMPILED = torch.compile(_prox, mode='reduce-overhead')
    except Exception:
        # Deliberate best-effort: remember the failure; callers use eager code.
        _L1_PROXIMAL_TORCH_COMPILE_FAILED = True
        _L1_PROXIMAL_TORCH_COMPILED = None
    return _L1_PROXIMAL_TORCH_COMPILED
35
+
36
+
37
class L1Penalty(Penalty):
    """
    L1 penalty: P(w) = α * ||w||₁

    The proximal operator is the soft thresholding function:
    prox_{λ·||·||₁}(z) = sign(z) * max(|z| - λ, 0)
    """

    name = "l1"
    is_convex = True

    def __init__(self, alpha: float = 1.0):
        """
        Parameters
        ----------
        alpha : float, default=1.0
            Regularization strength.

        Raises
        ------
        ValueError
            If ``alpha`` is negative.
        """
        if alpha < 0:
            raise ValueError(f"alpha must be non-negative, got {alpha}")
        self.alpha = alpha

    def value(self, coef):
        """P(w) = α * Σ|w_j|, returned as a Python float."""
        xp = _xp(coef)
        return self.alpha * float(xp.sum(xp.abs(coef)))

    def gradient(self, coef):
        """∇P(w) = α * sign(w); entries with w_j == 0 get 0 (sign(0) == 0)."""
        xp = _xp(coef)
        return self.alpha * xp.sign(coef)

    def proximal(
        self,
        w: np.ndarray,
        step: float,
        backend: str = "numpy"
    ) -> np.ndarray:
        """
        Soft thresholding: sign(z) * max(|z| - α*step, 0)

        Parameters
        ----------
        w : array
            Input array.
        step : float
            Step size.
        backend : str
            Backend: 'numpy', 'cupy', or 'torch'.

        Returns
        -------
        array
            Soft-thresholded result.
        """
        thresh = self.alpha * step

        # torch.compile fast path (performance optimization)
        if backend == "torch":
            compiled_fn = _get_l1_torch_compiled()
            if compiled_fn is not None:
                import torch
                # Pass the threshold as a 0-d tensor, as the MCP/SCAD
                # penalties do for ``step``. Depending on the PyTorch version,
                # torch.compile can specialize on Python float arguments. In
                # that case every new alpha*step value would force a
                # recompilation / new CUDA graph.
                thresh_t = torch.as_tensor(thresh, dtype=w.dtype, device=w.device)
                return compiled_fn(w, thresh_t)

        # Unified fallback across numpy/cupy/torch
        from statgpu.backends._array_ops import _soft_threshold
        return _soft_threshold(w, thresh)

    def get_params(self) -> dict:
        """Return the base-class parameters extended with ``alpha``."""
        params = super().get_params()
        params["alpha"] = self.alpha
        return params
@@ -0,0 +1,77 @@
"""
L2 penalty (Ridge) implementation.

P(w) = (α/2) * ||w||²₂

Provides ``L2Penalty``, whose proximal operator is the closed-form
shrinkage w / (1 + α*step).
"""

# Public API of this module.
__all__ = ["L2Penalty"]
8
+
9
+
10
+ from typing import Optional
11
+ from statgpu.backends._array_ops import _xp
12
+ import numpy as np
13
+ from statgpu.penalties._base import Penalty
14
+
15
+
16
class L2Penalty(Penalty):
    """
    Ridge penalty on the coefficient vector: P(w) = (α/2) * ||w||²₂.

    The scaled penalty ``step * P`` has a closed-form proximal operator,
    a uniform shrinkage towards zero:

        prox_{step·P}(z) = z / (1 + α * step)
    """

    name = "l2"
    is_convex = True

    def __init__(self, alpha: float = 1.0):
        """
        Parameters
        ----------
        alpha : float, default=1.0
            Regularization strength.

        Raises
        ------
        ValueError
            If ``alpha < 0``.
        """
        if alpha < 0:
            raise ValueError(f"alpha must be non-negative, got {alpha}")
        self.alpha = alpha

    def value(self, coef):
        """Return (α/2) * Σ w_j² as a Python float."""
        backend_module = _xp(coef)
        sum_of_squares = float(backend_module.sum(coef ** 2))
        return 0.5 * self.alpha * sum_of_squares

    def gradient(self, coef):
        """Return ∇P(w) = α * w (same array type as ``coef``)."""
        grad = self.alpha * coef
        return grad

    def proximal(
        self,
        w: np.ndarray,
        step: float,
        backend: str = "numpy"
    ) -> np.ndarray:
        """
        Shrink ``w`` by the factor 1 / (1 + α*step).

        Parameters
        ----------
        w : array
            Input array (numpy, cupy or torch).
        step : float
            Step size.
        backend : str
            Backend name; accepted for interface compatibility only, since
            the same arithmetic works on every array type.

        Returns
        -------
        array
            ``w`` scaled by the shrinkage factor.
        """
        shrinkage = 1.0 / (1.0 + self.alpha * step)
        return shrinkage * w

    def get_params(self) -> dict:
        """Return the base-class parameters extended with ``alpha``."""
        params = super().get_params()
        params.update(alpha=self.alpha)
        return params
@@ -0,0 +1,237 @@
"""
MCP penalty (Minimax Concave Penalty).

Zhang, Annals of Statistics 2010. Non-convex penalty with oracle property.

Element-wise:
p(w_j) = {
alpha * |w_j| - w_j^2 / (2*gamma) if |w_j| <= gamma*alpha
(1/2) * gamma * alpha^2 if |w_j| > gamma*alpha
}

Supports both FISTA direct (proximal) and LLA (lla_weights) optimization.
The constructor requires alpha > 0 and gamma > 1 (both finite). The proximal
operator has numpy, cupy (fused ElementwiseKernel) and torch (optionally
torch.compile'd) code paths.
"""

# Public API of this module.
__all__ = ["MCPPenalty"]
16
+
17
+ from typing import Optional
18
+ import numpy as np
19
+ from statgpu.penalties._base import Penalty
20
+ from statgpu.backends._array_ops import _xp
21
+ from statgpu.backends._utils import _to_float_scalar
22
+
23
# ---- torch.compile lazy-loader (fuses elementwise ops into 1 kernel) ---------
_MCP_PROXIMAL_TORCH_COMPILED = None
# Set to True once building the compiled kernel has raised. This stops a
# failing torch.compile from being re-attempted on every proximal call of a
# solver loop.
_MCP_PROXIMAL_TORCH_COMPILE_FAILED = False


def _get_mcp_torch_compiled():
    """Return a cached ``torch.compile``'d MCP proximal kernel, or ``None``.

    The kernel has signature ``(w, step, alpha, gamma)``. ``step`` must be a
    tensor because it is passed to ``torch.clamp``.

    ``None`` is returned in two cases:

    * ``statgpu.penalties._torch_compile_ok()`` reports that compilation
      should not be used. This is re-checked on every call.
    * Importing torch or calling ``torch.compile`` raised. This failure is
      remembered for the rest of the process.

    Callers are expected to fall back to an eager implementation on ``None``.
    """
    global _MCP_PROXIMAL_TORCH_COMPILED, _MCP_PROXIMAL_TORCH_COMPILE_FAILED
    if _MCP_PROXIMAL_TORCH_COMPILED is not None:
        return _MCP_PROXIMAL_TORCH_COMPILED
    if _MCP_PROXIMAL_TORCH_COMPILE_FAILED:
        return None
    from statgpu.penalties import _torch_compile_ok
    if not _torch_compile_ok():
        return None
    try:
        import torch

        def _prox(w, step, alpha, gamma):
            # Three-region MCP closed form. The step is clamped to 0.9*gamma
            # so that (1 - step/gamma) stays strictly positive.
            max_step = 0.9 * gamma
            step = torch.clamp(step, max=max_step)
            t = alpha * step
            abs_w = torch.abs(w)
            sign_w = torch.sign(w)
            r1 = abs_w <= t
            r3 = abs_w > gamma * alpha
            r2 = ~(r1 | r3)
            result = torch.where(r1,
                                 torch.zeros_like(w),
                                 torch.where(r2,
                                             sign_w * (abs_w - t) / (1.0 - step / gamma),
                                             w))
            return result

        _MCP_PROXIMAL_TORCH_COMPILED = torch.compile(
            _prox, dynamic=True, mode='reduce-overhead'
        )
    except Exception:
        # Deliberate best-effort: remember the failure; callers use eager code.
        _MCP_PROXIMAL_TORCH_COMPILE_FAILED = True
        _MCP_PROXIMAL_TORCH_COMPILED = None
    return _MCP_PROXIMAL_TORCH_COMPILED
56
+
57
+
58
class MCPPenalty(Penalty):
    """MCP penalty.

    Parameters
    ----------
    alpha : float, default=1.0
        Regularization strength. Must be finite and > 0.
    gamma : float, default=3.0
        Concavity parameter. Zhang recommends gamma > 1 (default 3.0).
        Must be finite and > 1.

    Raises
    ------
    ValueError
        If ``alpha`` or ``gamma`` violates the constraints above.

    Notes
    -----
    MCP is **non-convex** (``is_convex=False``). The objective function may
    contain multiple local minima. Different solvers (e.g. ``fista`` vs
    ``fista_bb``) can converge to different local minima with comparable
    objective values — a coefficient ``max|diff|`` up to ~1e-2 is expected
    and does not indicate a bug. The objective values should agree within
    ~1e-4 relative tolerance across runs.
    """

    name = "mcp"
    is_convex = False

    def __init__(self, alpha: float = 1.0, gamma: float = 3.0):
        # gamma > 1 keeps the proximal denominator (1 - step/gamma) positive
        # once step is clamped to 0.9*gamma in ``proximal``.
        if not np.isfinite(alpha) or alpha <= 0.0:
            raise ValueError("alpha must be a finite positive scalar for MCP penalty")
        if not np.isfinite(gamma) or gamma <= 1.0:
            raise ValueError("gamma must be a finite scalar greater than 1 for MCP penalty")
        self.alpha = alpha
        self.gamma = gamma

    # ----------------------------------------------------------------
    # Value
    # ----------------------------------------------------------------

    def value(self, coef: np.ndarray) -> float:
        """Return sum_j p(w_j), converted to a scalar via ``_to_float_scalar``.

        Coordinates with |w_j| <= gamma*alpha contribute
        alpha*|w_j| - w_j^2/(2*gamma). The rest each contribute the constant
        gamma*alpha^2/2, so the bool mask is summed as a count.
        """
        xp = _xp(coef)
        alpha = self.alpha
        gamma = self.gamma

        abs_w = xp.abs(coef)
        region1 = abs_w <= gamma * alpha
        region2 = ~region1
        total = xp.sum(alpha * abs_w[region1] - abs_w[region1] ** 2 / (2.0 * gamma))
        total = total + 0.5 * gamma * alpha ** 2 * xp.sum(region2)
        return _to_float_scalar(total)

    # ----------------------------------------------------------------
    # Gradient
    # ----------------------------------------------------------------

    def gradient(self, coef):
        """Elementwise derivative of the MCP penalty.

        Returns sign(w_j) * (alpha - |w_j|/gamma) where |w_j| <= gamma*alpha
        and 0 elsewhere. At w_j == 0 the result is 0 because sign(0) == 0.
        The output is allocated with ``coef``'s dtype when it has one.
        """
        xp = _xp(coef)
        abs_w = xp.abs(coef)
        sign_w = xp.sign(coef)
        alpha = self.alpha
        gamma = self.gamma

        grad = xp.zeros_like(coef, dtype=coef.dtype if hasattr(coef, 'dtype') else float)

        # Outside the mask the penalty is flat, so the gradient stays 0.
        mask1 = abs_w <= gamma * alpha
        grad[mask1] = sign_w[mask1] * (alpha - abs_w[mask1] / gamma)

        return grad

    # ----------------------------------------------------------------
    # Proximal operator (FISTA direct path)
    # ----------------------------------------------------------------

    # Lazy-loaded fused CuPy kernel (single launch vs ~10 intermediate arrays).
    # Stored on the class, so it is built once and shared by all instances.
    _MCP_PROXIMAL_CUPY = None

    def proximal(
        self,
        w,
        step: float,
        backend: str = "numpy",
    ):
        """Closed-form MCP proximal operator (three regions per coordinate).

        Clamp step < gamma so the three-region formula always applies.

        With s = min(step, 0.9*gamma) and t = alpha*s, each coordinate maps to:

        * |w| <= t: 0
        * t < |w| <= gamma*alpha: sign(w)*(|w| - t)/(1 - s/gamma)
        * |w| > gamma*alpha: w (unchanged)

        Because of the clamp, for step > 0.9*gamma the operator evaluated is
        the one for the clamped step, not the requested one.

        Backend behaviour:

        * ``'cupy'``: a fused ElementwiseKernel declared with float64
          in/out types, so the result is float64. The kernel re-applies the
          clamp internally.
        * ``'torch'``: the torch.compile'd kernel when available (step is
          passed as a tensor on ``w``'s device/dtype); otherwise an eager
          ``torch.where`` version.
        * anything else: the NumPy path, whose result is allocated with
          ``dtype=float``.
        """
        alpha = self.alpha
        gamma = self.gamma
        max_step = 0.9 * gamma
        if step > max_step:
            step = max_step
        t = alpha * step

        if backend == "cupy":
            import cupy as cp
            if MCPPenalty._MCP_PROXIMAL_CUPY is None:
                MCPPenalty._MCP_PROXIMAL_CUPY = cp.ElementwiseKernel(
                    'float64 w, float64 step, float64 alpha, float64 gamma',
                    'float64 result',
                    '''
                    double max_step = 0.9 * gamma;
                    double s = (step > max_step) ? max_step : step;
                    double abs_w = abs(w);
                    double t = alpha * s;
                    double sign_w = (w > 0.0) ? 1.0 : ((w < 0.0) ? -1.0 : 0.0);
                    if (abs_w <= t) {
                        result = 0.0;
                    } else if (abs_w <= gamma * alpha) {
                        result = sign_w * (abs_w - t) / (1.0 - s / gamma);
                    } else {
                        result = w;
                    }
                    ''',
                    'mcp_proximal',
                )
            return MCPPenalty._MCP_PROXIMAL_CUPY(w, step, alpha, gamma)

        elif backend == "torch":
            import torch
            compiled_fn = _get_mcp_torch_compiled()
            if compiled_fn is not None:
                # Step goes in as a tensor (not a Python float) for the
                # compiled kernel; the kernel clamps it with torch.clamp.
                step_t = torch.as_tensor(step, dtype=w.dtype, device=w.device)
                return compiled_fn(w, step_t, alpha, gamma)
            abs_w = torch.abs(w)
            sign_w = torch.sign(w)

            r1 = abs_w <= t
            r3 = abs_w > gamma * alpha
            r2 = ~(r1 | r3)
            result = torch.where(r1,
                                 torch.zeros_like(w),
                                 torch.where(r2,
                                             sign_w * (abs_w - t) / (1.0 - step / gamma),
                                             w))
            return result

        else:
            abs_w = np.abs(w)
            sign_w = np.sign(w)

            region1 = abs_w <= t
            region3 = abs_w > gamma * alpha
            region2 = ~(region1 | region3)

            # region1 entries stay at the zero fill value.
            result = np.zeros_like(w, dtype=float)
            result[region2] = (
                sign_w[region2]
                * (abs_w[region2] - t)
                / (1.0 - step / gamma)
            )
            result[region3] = w[region3]
            return result

    # ----------------------------------------------------------------
    # LLA weights (Local Linear Approximation path)
    # ----------------------------------------------------------------

    def lla_weights(self, coef):
        """
        LLA weights: w_j = P'(|coef_j|) — the subgradient of MCP at |coef_j|.

        w_j = {
        alpha - |coef_j| / gamma if |coef_j| <= gamma*alpha
        0 if |coef_j| > gamma*alpha
        }

        Accepts numpy, cupy, or torch arrays. Returns same backend type.
        The weights are allocated with ``zeros_like(coef)``, so they take
        coef's dtype. NOTE(review): an integer ``coef`` would truncate the
        weights; confirm callers always pass floating arrays.
        """
        alpha = self.alpha
        gamma = self.gamma

        xp = _xp(coef)
        abs_w = xp.abs(coef)
        weights = xp.zeros_like(coef)
        mask = abs_w <= gamma * alpha
        weights[mask] = alpha - abs_w[mask] / gamma
        return weights

    # ----------------------------------------------------------------

    def get_params(self) -> dict:
        """Return the base-class parameters extended with ``alpha`` and ``gamma``."""
        params = super().get_params()
        params.update({"alpha": self.alpha, "gamma": self.gamma})
        return params
@@ -0,0 +1,260 @@
"""
SCAD penalty (Smoothly Clipped Absolute Deviation).

Fan & Li, JASA 2001. Non-convex penalty with oracle property.

Element-wise:
p(w_j) = {
alpha * |w_j| if |w_j| <= alpha
-(w_j^2 - 2*a*alpha*|w_j| + alpha^2) / (2*(a-1)) if alpha < |w_j| <= a*alpha
(a+1)*alpha^2 / 2 if |w_j| > a*alpha
}

Supports both FISTA direct (proximal) and LLA (lla_weights) optimization.
The constructor requires alpha > 0 and a > 2 (both finite). The proximal
operator has numpy, cupy (fused ElementwiseKernel) and torch (optionally
torch.compile'd) code paths.
"""

# Public API of this module.
__all__ = ["SCADPenalty"]
17
+
18
+ from typing import Optional
19
+ import numpy as np
20
+ from statgpu.penalties._base import Penalty
21
+ from statgpu.backends._array_ops import _xp
22
+ from statgpu.backends._utils import _to_float_scalar
23
+
24
# ---- torch.compile lazy-loader (fuses elementwise ops into 1 kernel) ---------
_SCAD_PROXIMAL_TORCH_COMPILED = None
# Set to True once building the compiled kernel has raised. This stops a
# failing torch.compile from being re-attempted on every proximal call of a
# solver loop.
_SCAD_PROXIMAL_TORCH_COMPILE_FAILED = False


def _get_scad_torch_compiled():
    """Return a cached ``torch.compile``'d SCAD proximal kernel, or ``None``.

    The kernel has signature ``(w, step, alpha, a)``. ``step`` must be a
    tensor because it is passed to ``torch.clamp``.

    ``None`` is returned in two cases:

    * ``statgpu.penalties._torch_compile_ok()`` reports that compilation
      should not be used. This is re-checked on every call.
    * Importing torch or calling ``torch.compile`` raised. This failure is
      remembered for the rest of the process.

    Callers are expected to fall back to an eager implementation on ``None``.
    """
    global _SCAD_PROXIMAL_TORCH_COMPILED, _SCAD_PROXIMAL_TORCH_COMPILE_FAILED
    if _SCAD_PROXIMAL_TORCH_COMPILED is not None:
        return _SCAD_PROXIMAL_TORCH_COMPILED
    if _SCAD_PROXIMAL_TORCH_COMPILE_FAILED:
        return None
    from statgpu.penalties import _torch_compile_ok
    if not _torch_compile_ok():
        return None
    try:
        import torch

        def _prox(w, step, alpha, a):
            # Three-region SCAD closed form. The step is clamped to 0.9*(a-1)
            # so that the region-2 denominator (a - 1 - step) stays positive.
            max_step = 0.9 * (a - 1.0)
            step = torch.clamp(step, max=max_step)
            t = alpha * step
            abs_w = torch.abs(w)
            sign_w = torch.sign(w)
            r1 = abs_w <= alpha + t
            r3 = abs_w > a * alpha
            r2 = ~(r1 | r3)
            result = torch.where(r1,
                                 sign_w * torch.relu(abs_w - t),
                                 torch.where(r2,
                                             sign_w * ((a - 1.0) * abs_w - a * t) / (a - 1.0 - step),
                                             w))
            return result

        _SCAD_PROXIMAL_TORCH_COMPILED = torch.compile(
            _prox, dynamic=True, mode='reduce-overhead'
        )
    except Exception:
        # Deliberate best-effort: remember the failure; callers use eager code.
        _SCAD_PROXIMAL_TORCH_COMPILE_FAILED = True
        _SCAD_PROXIMAL_TORCH_COMPILED = None
    return _SCAD_PROXIMAL_TORCH_COMPILED
57
+
58
+
59
class SCADPenalty(Penalty):
    """SCAD penalty.

    Parameters
    ----------
    alpha : float, default=1.0
        Regularization strength. Must be finite and > 0.
    a : float, default=3.7
        Concavity parameter. Fan & Li recommend 3.7. Must be finite and > 2.

    Raises
    ------
    ValueError
        If ``alpha`` or ``a`` violates the constraints above.

    Notes
    -----
    SCAD is **non-convex** (``is_convex=False``). The objective function may
    contain multiple local minima. Different solvers (e.g. ``fista`` vs
    ``fista_bb``) can converge to different local minima with comparable
    objective values — a coefficient ``max|diff|`` up to ~1e-2 is expected
    and does not indicate a bug. The objective values should agree within
    ~1e-4 relative tolerance across runs.
    """

    name = "scad"
    is_convex = False

    def __init__(self, alpha: float = 1.0, a: float = 3.7):
        if not np.isfinite(alpha) or alpha <= 0.0:
            raise ValueError("alpha must be a finite positive scalar for SCAD penalty")
        if not np.isfinite(a) or a <= 2.0:
            raise ValueError("a must be a finite scalar greater than 2 for SCAD penalty")
        self.alpha = alpha
        self.a = a

    # ----------------------------------------------------------------
    # Value
    # ----------------------------------------------------------------

    def value(self, coef: np.ndarray) -> float:
        """Return sum_j p(w_j), converted to a scalar via ``_to_float_scalar``.

        Uses the three SCAD regions from the module docstring. In the flat
        region each coordinate contributes (a+1)*alpha^2/2, so that bool mask
        is summed as a count.
        """
        xp = _xp(coef)
        a = self.a
        alpha = self.alpha

        abs_w = xp.abs(coef)
        region1 = abs_w <= alpha
        region2 = (abs_w > alpha) & (abs_w <= a * alpha)
        region3 = abs_w > a * alpha
        total = alpha * xp.sum(abs_w[region1])
        total = total + xp.sum(
            -(abs_w[region2] ** 2 - 2 * a * alpha * abs_w[region2] + alpha ** 2)
            / (2.0 * (a - 1.0))
        )
        total = total + (a + 1.0) * alpha ** 2 / 2.0 * xp.sum(region3)
        return _to_float_scalar(total)

    # ----------------------------------------------------------------
    # Gradient
    # ----------------------------------------------------------------

    def gradient(self, coef):
        """Elementwise derivative of the SCAD penalty.

        The output is allocated with ``coef``'s dtype when it has one. At
        w_j == 0 the result is 0 because sign(0) == 0.
        """
        xp = _xp(coef)
        abs_w = xp.abs(coef)
        sign_w = xp.sign(coef)
        a = self.a
        alpha = self.alpha

        grad = xp.zeros_like(coef, dtype=coef.dtype if hasattr(coef, 'dtype') else float)

        # Region 1: |w| <= alpha → alpha * sign(w)
        mask1 = abs_w <= alpha
        grad[mask1] = alpha * sign_w[mask1]

        # Region 2: alpha < |w| <= a*alpha → (a*alpha*sign - w) / (a-1)
        mask2 = (abs_w > alpha) & (abs_w <= a * alpha)
        grad[mask2] = (a * alpha * sign_w[mask2] - coef[mask2]) / (a - 1.0)

        # Region 3: |w| > a*alpha → 0
        return grad

    # ----------------------------------------------------------------
    # Proximal operator (FISTA direct path)
    # ----------------------------------------------------------------

    # Lazy-loaded fused CuPy kernel (single launch vs ~15 intermediate arrays).
    # Stored on the class, so it is built once and shared by all instances.
    _SCAD_PROXIMAL_CUPY = None

    def proximal(
        self,
        w,
        step: float,
        backend: str = "numpy",
    ):
        """Closed-form SCAD proximal operator (three regions per coordinate).

        With s = min(step, 0.9*(a-1)) and t = alpha*s, each coordinate maps to:

        * |w| <= alpha + t: soft threshold sign(w)*max(|w| - t, 0)
        * alpha + t < |w| <= a*alpha: sign(w)*((a-1)*|w| - a*t)/(a - 1 - s)
        * |w| > a*alpha: w (unchanged)

        When step >= a-1 the region-2 denominator is zero or negative and the
        three-region formula no longer applies. The step is therefore clamped
        to 0.9*(a-1). For larger requested steps, the operator evaluated is
        the one for the clamped step, not the requested one.

        Note that the threshold here is t = alpha*step, i.e. it does scale
        with the step. The clamp was motivated by R ncvreg's per-coordinate
        updates, which express their threshold in unscaled units of alpha.

        Backend behaviour:

        * ``'cupy'``: a fused ElementwiseKernel declared with float64
          in/out types, so the result is float64. The kernel re-applies the
          clamp internally.
        * ``'torch'``: the torch.compile'd kernel when available (step is
          passed as a tensor on ``w``'s device/dtype); otherwise an eager
          ``torch.where`` version.
        * anything else: the NumPy path, whose result is allocated with
          ``dtype=float``.
        """
        alpha = self.alpha
        a = self.a
        # Clamp step: ensure a > 1 + step (three-region condition).
        # Use 0.9*(a-1) as max to avoid the singularity at step = a-1.
        max_step = 0.9 * (a - 1.0)
        if step > max_step:
            step = max_step
        t = alpha * step

        if backend == "cupy":
            import cupy as cp
            if SCADPenalty._SCAD_PROXIMAL_CUPY is None:
                SCADPenalty._SCAD_PROXIMAL_CUPY = cp.ElementwiseKernel(
                    'float64 w, float64 step, float64 alpha, float64 a',
                    'float64 result',
                    '''
                    double max_step = 0.9 * (a - 1.0);
                    double s = (step > max_step) ? max_step : step;
                    double abs_w = abs(w);
                    double t = alpha * s;
                    double sign_w = (w > 0.0) ? 1.0 : ((w < 0.0) ? -1.0 : 0.0);
                    if (abs_w <= alpha + t) {
                        double v = abs_w - t;
                        result = sign_w * (v > 0.0 ? v : 0.0);
                    } else if (abs_w <= a * alpha) {
                        result = sign_w * ((a - 1.0) * abs_w - a * t) / (a - 1.0 - s);
                    } else {
                        result = w;
                    }
                    ''',
                    'scad_proximal',
                )
            return SCADPenalty._SCAD_PROXIMAL_CUPY(w, step, alpha, a)

        elif backend == "torch":
            import torch
            compiled_fn = _get_scad_torch_compiled()
            if compiled_fn is not None:
                # Step goes in as a tensor (not a Python float) for the
                # compiled kernel; the kernel clamps it with torch.clamp.
                step_t = torch.as_tensor(step, dtype=w.dtype, device=w.device)
                return compiled_fn(w, step_t, alpha, a)
            abs_w = torch.abs(w)
            sign_w = torch.sign(w)

            r1 = abs_w <= alpha + t
            r3 = abs_w > a * alpha
            r2 = ~(r1 | r3)
            result = torch.where(r1,
                                 sign_w * torch.relu(abs_w - t),
                                 torch.where(r2,
                                             sign_w * ((a - 1.0) * abs_w - a * t) / (a - 1.0 - step),
                                             w))
            return result

        else:
            abs_w = np.abs(w)
            sign_w = np.sign(w)

            region1 = abs_w <= alpha + t
            region3 = abs_w > a * alpha
            region2 = ~(region1 | region3)

            result = np.zeros_like(w, dtype=float)
            result[region1] = sign_w[region1] * np.maximum(abs_w[region1] - t, 0.0)
            result[region2] = (
                sign_w[region2]
                * ((a - 1.0) * abs_w[region2] - a * t)
                / (a - 1.0 - step)
            )
            result[region3] = w[region3]
            return result

    # ----------------------------------------------------------------
    # LLA weights (Local Linear Approximation path)
    # ----------------------------------------------------------------

    def lla_weights(self, coef):
        """
        LLA weights: w_j = P'(|coef_j|) — the subgradient of SCAD at |coef_j|.

        w_j = {
        alpha if |coef_j| <= alpha
        (a*alpha - |coef_j|) / (a - 1) if alpha < |coef_j| <= a*alpha
        0 if |coef_j| > a*alpha
        }

        Accepts numpy, cupy, or torch arrays. Returns same backend type.
        The weights are allocated with ``full_like(coef, alpha)``, so they take
        coef's dtype. NOTE(review): an integer ``coef`` would truncate the
        weights; confirm callers always pass floating arrays.
        """
        a = self.a
        alpha = self.alpha

        xp = _xp(coef)
        abs_w = xp.abs(coef)
        # Start from the region-1 value (alpha) and overwrite regions 2 and 3.
        weights = xp.full_like(coef, alpha)
        mask2 = (abs_w > alpha) & (abs_w <= a * alpha)
        weights[mask2] = (a * alpha - abs_w[mask2]) / (a - 1.0)
        mask3 = abs_w > a * alpha
        weights[mask3] = 0.0
        return weights

    # ----------------------------------------------------------------

    def get_params(self) -> dict:
        """Return the base-class parameters extended with ``alpha`` and ``a``."""
        params = super().get_params()
        params.update({"alpha": self.alpha, "a": self.a})
        return params
@@ -0,0 +1,5 @@
"""Semiparametric models with GPU support.

Re-exports ``GAM`` from the private ``._gam`` module.
"""

from ._gam import GAM

__all__ = ['GAM']