statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
statgpu/penalties/_l1.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""
|
|
2
|
+
L1 penalty (Lasso) implementation.
|
|
3
|
+
|
|
4
|
+
P(w) = α * ||w||₁
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
__all__ = ["L1Penalty"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
from typing import Optional
|
|
11
|
+
from statgpu.backends._array_ops import _xp
|
|
12
|
+
import numpy as np
|
|
13
|
+
from statgpu.penalties._base import Penalty
|
|
14
|
+
|
|
15
|
+
# ---- torch.compile lazy-loader (fuses elementwise ops into 1 kernel) ---------
|
|
16
|
+
# Cache for the compiled soft-thresholding kernel (None until a compile succeeds).
_L1_PROXIMAL_TORCH_COMPILED = None


def _get_l1_torch_compiled():
    """Return the cached ``torch.compile``-d soft-thresholding function, or None.

    The returned callable has the signature ``fn(w, thresh)`` and computes
    ``sign(w) * relu(|w| - thresh)``.  A successfully compiled function is
    stored in the module-level ``_L1_PROXIMAL_TORCH_COMPILED`` and reused on
    later calls.  When ``_torch_compile_ok()`` reports False, or importing
    torch / wrapping the function raises, ``None`` is returned; nothing is
    cached in that case, so availability is re-checked on the next call.
    """
    global _L1_PROXIMAL_TORCH_COMPILED

    cached = _L1_PROXIMAL_TORCH_COMPILED
    if cached is not None:
        return cached

    # Lazy import of the package-level capability check.
    from statgpu.penalties import _torch_compile_ok

    compiled = None
    if _torch_compile_ok():
        try:
            import torch

            def _soft_threshold_kernel(w, thresh):
                return torch.sign(w) * torch.relu(torch.abs(w) - thresh)

            compiled = torch.compile(_soft_threshold_kernel, mode='reduce-overhead')
        except Exception:
            # Any failure just disables the fast path; callers use the eager fallback.
            compiled = None

    _L1_PROXIMAL_TORCH_COMPILED = compiled
    return compiled
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class L1Penalty(Penalty):
    """
    L1 penalty: P(w) = α * ||w||₁

    The proximal operator of ``step * P`` is the soft-thresholding function:
        prox_{step·α·||·||₁}(z) = sign(z) * max(|z| - α·step, 0)
    """

    name = "l1"
    is_convex = True

    def __init__(self, alpha: float = 1.0):
        """
        Parameters
        ----------
        alpha : float, default=1.0
            Regularization strength. Must be non-negative.

        Raises
        ------
        ValueError
            If ``alpha`` is negative or NaN.
        """
        # ``not alpha >= 0`` (rather than ``alpha < 0``) also rejects NaN,
        # which compares False against everything and would otherwise slip
        # through and silently produce NaN coefficients.
        if not alpha >= 0:
            raise ValueError(f"alpha must be non-negative, got {alpha}")
        self.alpha = alpha

    def value(self, coef):
        """P(w) = α * Σ|w_j|, returned as a Python float."""
        xp = _xp(coef)
        return self.alpha * float(xp.sum(xp.abs(coef)))

    def gradient(self, coef):
        """∇P(w) = α * sign(w) (elementwise; computed with the array's own ``sign``)."""
        xp = _xp(coef)
        return self.alpha * xp.sign(coef)

    def proximal(
        self,
        w: np.ndarray,
        step: float,
        backend: str = "numpy"
    ) -> np.ndarray:
        """
        Soft thresholding: sign(w) * max(|w| - α*step, 0)

        Parameters
        ----------
        w : array
            Input array.
        step : float
            Step size.
        backend : str
            Backend: 'numpy', 'cupy', or 'torch'.

        Returns
        -------
        array
            Soft-thresholded result.
        """
        thresh = self.alpha * step

        # torch.compile fast path (performance optimization)
        if backend == "torch":
            compiled_fn = _get_l1_torch_compiled()
            if compiled_fn is not None:
                import torch

                # Pass the threshold as a 0-d tensor rather than a Python float:
                # torch.compile may guard on (specialize to) the value of a float
                # argument, and the step size typically changes every iteration,
                # which could otherwise force repeated recompilation.
                thresh_t = torch.as_tensor(thresh, dtype=w.dtype, device=w.device)
                return compiled_fn(w, thresh_t)

        # Unified fallback across numpy/cupy/torch
        from statgpu.backends._array_ops import _soft_threshold
        return _soft_threshold(w, thresh)

    def get_params(self) -> dict:
        """Return the base-class parameters plus ``alpha``."""
        params = super().get_params()
        params["alpha"] = self.alpha
        return params
|
statgpu/penalties/_l2.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""
|
|
2
|
+
L2 penalty (Ridge) implementation.
|
|
3
|
+
|
|
4
|
+
P(w) = (α/2) * ||w||²₂
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
__all__ = ["L2Penalty"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
from typing import Optional
|
|
11
|
+
from statgpu.backends._array_ops import _xp
|
|
12
|
+
import numpy as np
|
|
13
|
+
from statgpu.penalties._base import Penalty
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class L2Penalty(Penalty):
    """
    L2 penalty (Ridge): P(w) = (α/2) * ||w||²₂

    The proximal operator of ``step * P`` has a closed-form solution:
        prox_{step·(α/2)·||·||²₂}(z) = z / (1 + α·step)
    """

    name = "l2"
    is_convex = True

    def __init__(self, alpha: float = 1.0):
        """
        Parameters
        ----------
        alpha : float, default=1.0
            Regularization strength. Must be non-negative.

        Raises
        ------
        ValueError
            If ``alpha`` is negative or NaN.
        """
        # ``not alpha >= 0`` also rejects NaN, which ``alpha < 0`` lets through.
        if not alpha >= 0:
            raise ValueError(f"alpha must be non-negative, got {alpha}")
        self.alpha = alpha

    def value(self, coef):
        """P(w) = (α/2) * Σw_j², returned as a Python float."""
        xp = _xp(coef)
        return 0.5 * self.alpha * float(xp.sum(coef ** 2))

    def gradient(self, coef):
        """∇P(w) = α * w"""
        return self.alpha * coef

    def proximal(
        self,
        w: np.ndarray,
        step: float,
        backend: str = "numpy"
    ) -> np.ndarray:
        """
        Closed-form for L2: w / (1 + α*step)

        Parameters
        ----------
        w : array
            Input array.
        step : float
            Step size.
        backend : str
            Backend: 'numpy', 'cupy', or 'torch'. Accepted for interface
            compatibility only; the scaling below is backend-agnostic and does
            not read this argument.

        Returns
        -------
        array
            Scaled result (same array type as ``w``).
        """
        scale = 1.0 / (1.0 + self.alpha * step)

        return scale * w

    def get_params(self) -> dict:
        """Return the base-class parameters plus ``alpha``."""
        params = super().get_params()
        params["alpha"] = self.alpha
        return params
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCP penalty (Minimax Concave Penalty).
|
|
3
|
+
|
|
4
|
+
Zhang, Annals of Statistics 2010. Non-convex penalty with oracle property.
|
|
5
|
+
|
|
6
|
+
Element-wise:
|
|
7
|
+
p(w_j) = {
|
|
8
|
+
alpha * |w_j| - w_j^2 / (2*gamma) if |w_j| <= gamma*alpha
|
|
9
|
+
(1/2) * gamma * alpha^2 if |w_j| > gamma*alpha
|
|
10
|
+
}
|
|
11
|
+
|
|
12
|
+
Supports both FISTA direct (proximal) and LLA (lla_weights) optimization.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
__all__ = ["MCPPenalty"]
|
|
16
|
+
|
|
17
|
+
from typing import Optional
|
|
18
|
+
import numpy as np
|
|
19
|
+
from statgpu.penalties._base import Penalty
|
|
20
|
+
from statgpu.backends._array_ops import _xp
|
|
21
|
+
from statgpu.backends._utils import _to_float_scalar
|
|
22
|
+
|
|
23
|
+
# ---- torch.compile lazy-loader (fuses elementwise ops into 1 kernel) ---------
|
|
24
|
+
# Cache for the compiled MCP proximal kernel (None until a compile succeeds).
_MCP_PROXIMAL_TORCH_COMPILED = None


def _get_mcp_torch_compiled():
    """Return the cached ``torch.compile``-d MCP proximal function, or None.

    The returned callable has the signature ``fn(w, step, alpha, gamma)``.
    It clamps ``step`` to at most ``0.9 * gamma``, sets ``t = alpha * step``,
    and maps each entry of ``w`` to:

    * 0 where ``|w| <= t``;
    * ``sign(w) * (|w| - t) / (1 - step / gamma)`` where ``t < |w| <= gamma*alpha``;
    * ``w`` unchanged where ``|w| > gamma * alpha``.

    A successful compilation is cached in ``_MCP_PROXIMAL_TORCH_COMPILED``.
    When ``_torch_compile_ok()`` is False, or importing torch / wrapping the
    function raises, ``None`` is returned and availability is re-checked on
    the next call.
    """
    global _MCP_PROXIMAL_TORCH_COMPILED

    cached = _MCP_PROXIMAL_TORCH_COMPILED
    if cached is not None:
        return cached

    # Lazy import of the package-level capability check.
    from statgpu.penalties import _torch_compile_ok

    compiled = None
    if _torch_compile_ok():
        try:
            import torch

            def _mcp_prox_kernel(w, step, alpha, gamma):
                step = torch.clamp(step, max=0.9 * gamma)
                t = alpha * step
                abs_w = torch.abs(w)
                sign_w = torch.sign(w)
                zero_region = abs_w <= t
                identity_region = abs_w > gamma * alpha
                shrink_region = ~(zero_region | identity_region)
                shrunk = sign_w * (abs_w - t) / (1.0 - step / gamma)
                return torch.where(
                    zero_region,
                    torch.zeros_like(w),
                    torch.where(shrink_region, shrunk, w),
                )

            compiled = torch.compile(_mcp_prox_kernel, dynamic=True, mode='reduce-overhead')
        except Exception:
            # Any failure just disables the fast path; callers use the eager fallback.
            compiled = None

    _MCP_PROXIMAL_TORCH_COMPILED = compiled
    return compiled
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class MCPPenalty(Penalty):
    """MCP penalty.

    Parameters
    ----------
    alpha : float, default=1.0
        Regularization strength. Must be finite and positive.
    gamma : float, default=3.0
        Concavity parameter. Must be finite and greater than 1
        (Zhang recommends gamma > 1; default 3.0).

    Notes
    -----
    MCP is **non-convex** (``is_convex=False``). The objective function may
    contain multiple local minima. Different solvers (e.g. ``fista`` vs
    ``fista_bb``) can converge to different local minima with comparable
    objective values — a coefficient ``max|diff|`` up to ~1e-2 is expected
    and does not indicate a bug. The objective values should agree within
    ~1e-4 relative tolerance across runs.
    """

    name = "mcp"
    is_convex = False

    def __init__(self, alpha: float = 1.0, gamma: float = 3.0):
        """Validate and store the penalty parameters.

        Raises
        ------
        ValueError
            If ``alpha`` is not finite and positive, or ``gamma`` is not
            finite and greater than 1.
        """
        if not np.isfinite(alpha) or alpha <= 0.0:
            raise ValueError("alpha must be a finite positive scalar for MCP penalty")
        if not np.isfinite(gamma) or gamma <= 1.0:
            raise ValueError("gamma must be a finite scalar greater than 1 for MCP penalty")
        self.alpha = alpha
        self.gamma = gamma

    # ----------------------------------------------------------------
    # Value
    # ----------------------------------------------------------------

    def value(self, coef: np.ndarray) -> float:
        """Return sum_j p(w_j) as a Python float.

        p(w) = alpha*|w| - w^2/(2*gamma)   if |w| <= gamma*alpha
        p(w) = gamma*alpha^2 / 2           otherwise
        """
        xp = _xp(coef)
        alpha = self.alpha
        gamma = self.gamma

        abs_w = xp.abs(coef)
        region1 = abs_w <= gamma * alpha
        region2 = ~region1
        total = xp.sum(alpha * abs_w[region1] - abs_w[region1] ** 2 / (2.0 * gamma))
        # Each coordinate beyond gamma*alpha contributes the same constant.
        total = total + 0.5 * gamma * alpha ** 2 * xp.sum(region2)
        return _to_float_scalar(total)

    # ----------------------------------------------------------------
    # Gradient
    # ----------------------------------------------------------------

    def gradient(self, coef):
        """Elementwise derivative of the penalty.

        sign(w) * (alpha - |w|/gamma) where |w| <= gamma*alpha, and 0 elsewhere.
        Output has the same array type and dtype as ``coef``.
        """
        xp = _xp(coef)
        abs_w = xp.abs(coef)
        sign_w = xp.sign(coef)
        alpha = self.alpha
        gamma = self.gamma

        grad = xp.zeros_like(coef, dtype=coef.dtype if hasattr(coef, 'dtype') else float)

        mask1 = abs_w <= gamma * alpha
        grad[mask1] = sign_w[mask1] * (alpha - abs_w[mask1] / gamma)

        return grad

    # ----------------------------------------------------------------
    # Proximal operator (FISTA direct path)
    # ----------------------------------------------------------------

    # Lazy-loaded fused CuPy kernel (single launch vs ~10 intermediate arrays)
    _MCP_PROXIMAL_CUPY = None

    def proximal(
        self,
        w,
        step: float,
        backend: str = "numpy",
    ):
        """Closed-form MCP proximal operator (three regions per coordinate).

        ``step`` is first clamped to at most ``0.9 * gamma``: the middle-region
        formula divides by ``1 - step/gamma``, which is zero at ``step = gamma``
        and negative beyond. For larger requested steps the result is the prox
        for the clamped step. With ``t = alpha * step``:

        * ``|w| <= t``                  -> 0
        * ``t < |w| <= gamma*alpha``    -> ``sign(w) * (|w| - t) / (1 - step/gamma)``
        * ``|w| > gamma*alpha``         -> ``w`` (unchanged)

        Parameters
        ----------
        w : array
            Input array (numpy, cupy or torch, matching ``backend``).
        step : float
            Step size.
        backend : str
            'cupy', 'torch'; any other value uses the NumPy path.

        Returns
        -------
        array
            NumPy path: a float64 array. CuPy path: output of a float64
            ElementwiseKernel. Torch path: a tensor computed from ``w``.
        """
        alpha = self.alpha
        gamma = self.gamma
        max_step = 0.9 * gamma
        if step > max_step:
            step = max_step
        t = alpha * step

        if backend == "cupy":
            import cupy as cp
            # NOTE(review): the kernel declares float64 for ``w``; float32 CuPy
            # inputs may be rejected by CuPy's type check — confirm if needed.
            if MCPPenalty._MCP_PROXIMAL_CUPY is None:
                MCPPenalty._MCP_PROXIMAL_CUPY = cp.ElementwiseKernel(
                    'float64 w, float64 step, float64 alpha, float64 gamma',
                    'float64 result',
                    '''
                    double max_step = 0.9 * gamma;
                    double s = (step > max_step) ? max_step : step;
                    double abs_w = abs(w);
                    double t = alpha * s;
                    double sign_w = (w > 0.0) ? 1.0 : ((w < 0.0) ? -1.0 : 0.0);
                    if (abs_w <= t) {
                        result = 0.0;
                    } else if (abs_w <= gamma * alpha) {
                        result = sign_w * (abs_w - t) / (1.0 - s / gamma);
                    } else {
                        result = w;
                    }
                    ''',
                    'mcp_proximal',
                )
            return MCPPenalty._MCP_PROXIMAL_CUPY(w, step, alpha, gamma)

        elif backend == "torch":
            import torch
            compiled_fn = _get_mcp_torch_compiled()
            if compiled_fn is not None:
                # Every scalar goes in as a 0-d tensor on w's device/dtype.
                # torch.compile may guard on (specialize to) the values of Python
                # float arguments, so plain floats for alpha/gamma could trigger a
                # recompile for each distinct penalty strength (e.g. along a path).
                def _as_t(x):
                    return torch.as_tensor(x, dtype=w.dtype, device=w.device)

                return compiled_fn(w, _as_t(step), _as_t(alpha), _as_t(gamma))
            abs_w = torch.abs(w)
            sign_w = torch.sign(w)

            r1 = abs_w <= t
            r3 = abs_w > gamma * alpha
            r2 = ~(r1 | r3)
            result = torch.where(r1,
                                 torch.zeros_like(w),
                                 torch.where(r2,
                                             sign_w * (abs_w - t) / (1.0 - step / gamma),
                                             w))
            return result

        else:
            abs_w = np.abs(w)
            sign_w = np.sign(w)

            region1 = abs_w <= t
            region3 = abs_w > gamma * alpha
            region2 = ~(region1 | region3)

            # region1 entries stay at zero.
            result = np.zeros_like(w, dtype=float)
            result[region2] = (
                sign_w[region2]
                * (abs_w[region2] - t)
                / (1.0 - step / gamma)
            )
            result[region3] = w[region3]
            return result

    # ----------------------------------------------------------------
    # LLA weights (Local Linear Approximation path)
    # ----------------------------------------------------------------

    def lla_weights(self, coef):
        """
        LLA weights: w_j = P'(|coef_j|) — the subgradient of MCP at |coef_j|.

            w_j = {
                alpha - |coef_j| / gamma    if |coef_j| <= gamma*alpha
                0                           if |coef_j| > gamma*alpha
            }

        Accepts numpy, cupy, or torch arrays. Returns same backend type.
        """
        alpha = self.alpha
        gamma = self.gamma

        xp = _xp(coef)
        abs_w = xp.abs(coef)
        weights = xp.zeros_like(coef)
        mask = abs_w <= gamma * alpha
        weights[mask] = alpha - abs_w[mask] / gamma
        return weights

    # ----------------------------------------------------------------

    def get_params(self) -> dict:
        """Return the base-class parameters plus ``alpha`` and ``gamma``."""
        params = super().get_params()
        params.update({"alpha": self.alpha, "gamma": self.gamma})
        return params
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SCAD penalty (Smoothly Clipped Absolute Deviation).
|
|
3
|
+
|
|
4
|
+
Fan & Li, JASA 2001. Non-convex penalty with oracle property.
|
|
5
|
+
|
|
6
|
+
Element-wise:
|
|
7
|
+
p(w_j) = {
|
|
8
|
+
alpha * |w_j| if |w_j| <= alpha
|
|
9
|
+
-(w_j^2 - 2*a*alpha*|w_j| + alpha^2) / (2*(a-1)) if alpha < |w_j| <= a*alpha
|
|
10
|
+
(a+1)*alpha^2 / 2 if |w_j| > a*alpha
|
|
11
|
+
}
|
|
12
|
+
|
|
13
|
+
Supports both FISTA direct (proximal) and LLA (lla_weights) optimization.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
__all__ = ["SCADPenalty"]
|
|
17
|
+
|
|
18
|
+
from typing import Optional
|
|
19
|
+
import numpy as np
|
|
20
|
+
from statgpu.penalties._base import Penalty
|
|
21
|
+
from statgpu.backends._array_ops import _xp
|
|
22
|
+
from statgpu.backends._utils import _to_float_scalar
|
|
23
|
+
|
|
24
|
+
# ---- torch.compile lazy-loader (fuses elementwise ops into 1 kernel) ---------
|
|
25
|
+
# Cache for the compiled SCAD proximal kernel (None until a compile succeeds).
_SCAD_PROXIMAL_TORCH_COMPILED = None


def _get_scad_torch_compiled():
    """Return the cached ``torch.compile``-d SCAD proximal function, or None.

    The returned callable has the signature ``fn(w, step, alpha, a)``.
    It clamps ``step`` to at most ``0.9 * (a - 1)``, sets ``t = alpha * step``,
    and maps each entry of ``w`` to:

    * ``sign(w) * relu(|w| - t)`` where ``|w| <= alpha + t``;
    * ``sign(w) * ((a-1)|w| - a*t) / (a - 1 - step)`` where ``alpha + t < |w| <= a*alpha``;
    * ``w`` unchanged where ``|w| > a * alpha``.

    A successful compilation is cached in ``_SCAD_PROXIMAL_TORCH_COMPILED``.
    When ``_torch_compile_ok()`` is False, or importing torch / wrapping the
    function raises, ``None`` is returned and availability is re-checked on
    the next call.
    """
    global _SCAD_PROXIMAL_TORCH_COMPILED

    cached = _SCAD_PROXIMAL_TORCH_COMPILED
    if cached is not None:
        return cached

    # Lazy import of the package-level capability check.
    from statgpu.penalties import _torch_compile_ok

    compiled = None
    if _torch_compile_ok():
        try:
            import torch

            def _scad_prox_kernel(w, step, alpha, a):
                a_minus_1 = a - 1.0
                step = torch.clamp(step, max=0.9 * a_minus_1)
                t = alpha * step
                abs_w = torch.abs(w)
                sign_w = torch.sign(w)
                soft_region = abs_w <= alpha + t
                identity_region = abs_w > a * alpha
                shrink_region = ~(soft_region | identity_region)
                soft = sign_w * torch.relu(abs_w - t)
                shrunk = sign_w * (a_minus_1 * abs_w - a * t) / (a_minus_1 - step)
                return torch.where(
                    soft_region,
                    soft,
                    torch.where(shrink_region, shrunk, w),
                )

            compiled = torch.compile(_scad_prox_kernel, dynamic=True, mode='reduce-overhead')
        except Exception:
            # Any failure just disables the fast path; callers use the eager fallback.
            compiled = None

    _SCAD_PROXIMAL_TORCH_COMPILED = compiled
    return compiled
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class SCADPenalty(Penalty):
    """SCAD penalty.

    Parameters
    ----------
    alpha : float, default=1.0
        Regularization strength. Must be finite and positive.
    a : float, default=3.7
        Concavity parameter. Must be finite and greater than 2
        (Fan & Li recommend 3.7).

    Notes
    -----
    SCAD is **non-convex** (``is_convex=False``). The objective function may
    contain multiple local minima. Different solvers (e.g. ``fista`` vs
    ``fista_bb``) can converge to different local minima with comparable
    objective values — a coefficient ``max|diff|`` up to ~1e-2 is expected
    and does not indicate a bug. The objective values should agree within
    ~1e-4 relative tolerance across runs.
    """

    name = "scad"
    is_convex = False

    def __init__(self, alpha: float = 1.0, a: float = 3.7):
        """Validate and store the penalty parameters.

        Raises
        ------
        ValueError
            If ``alpha`` is not finite and positive, or ``a`` is not finite
            and greater than 2.
        """
        if not np.isfinite(alpha) or alpha <= 0.0:
            raise ValueError("alpha must be a finite positive scalar for SCAD penalty")
        if not np.isfinite(a) or a <= 2.0:
            raise ValueError("a must be a finite scalar greater than 2 for SCAD penalty")
        self.alpha = alpha
        self.a = a

    # ----------------------------------------------------------------
    # Value
    # ----------------------------------------------------------------

    def value(self, coef: np.ndarray) -> float:
        """Return sum_j p(w_j) as a Python float.

        p(w) = alpha*|w|                                   if |w| <= alpha
        p(w) = -(w^2 - 2*a*alpha*|w| + alpha^2) / (2(a-1)) if alpha < |w| <= a*alpha
        p(w) = (a+1) * alpha^2 / 2                         if |w| > a*alpha
        """
        xp = _xp(coef)
        a = self.a
        alpha = self.alpha

        abs_w = xp.abs(coef)
        region1 = abs_w <= alpha
        region2 = (abs_w > alpha) & (abs_w <= a * alpha)
        region3 = abs_w > a * alpha
        total = alpha * xp.sum(abs_w[region1])
        total = total + xp.sum(
            -(abs_w[region2] ** 2 - 2 * a * alpha * abs_w[region2] + alpha ** 2)
            / (2.0 * (a - 1.0))
        )
        # Each coordinate beyond a*alpha contributes the same constant.
        total = total + (a + 1.0) * alpha ** 2 / 2.0 * xp.sum(region3)
        return _to_float_scalar(total)

    # ----------------------------------------------------------------
    # Gradient
    # ----------------------------------------------------------------

    def gradient(self, coef):
        """Elementwise derivative of the penalty.

        alpha*sign(w) for |w| <= alpha, (a*alpha*sign(w) - w)/(a-1) for
        alpha < |w| <= a*alpha, and 0 beyond. Output has the same array type
        and dtype as ``coef``.
        """
        xp = _xp(coef)
        abs_w = xp.abs(coef)
        sign_w = xp.sign(coef)
        a = self.a
        alpha = self.alpha

        grad = xp.zeros_like(coef, dtype=coef.dtype if hasattr(coef, 'dtype') else float)

        # Region 1: |w| <= alpha → alpha * sign(w)
        mask1 = abs_w <= alpha
        grad[mask1] = alpha * sign_w[mask1]

        # Region 2: alpha < |w| <= a*alpha → (a*alpha*sign - w) / (a-1)
        mask2 = (abs_w > alpha) & (abs_w <= a * alpha)
        grad[mask2] = (a * alpha * sign_w[mask2] - coef[mask2]) / (a - 1.0)

        # Region 3: |w| > a*alpha → 0 (left at the zero initialisation)
        return grad

    # ----------------------------------------------------------------
    # Proximal operator (FISTA direct path)
    # ----------------------------------------------------------------

    # Lazy-loaded fused CuPy kernel (single launch vs ~15 intermediate arrays)
    _SCAD_PROXIMAL_CUPY = None

    def proximal(
        self,
        w,
        step: float,
        backend: str = "numpy",
    ):
        """Closed-form SCAD proximal operator (three regions per coordinate).

        With ``t = alpha * step`` (after the clamp described below):

        * ``|w| <= alpha + t``            -> ``sign(w) * max(|w| - t, 0)``
        * ``alpha + t < |w| <= a*alpha``  -> ``sign(w) * ((a-1)|w| - a*t) / (a - 1 - step)``
        * ``|w| > a*alpha``               -> ``w`` (unchanged)

        The middle-region formula requires ``step < a - 1``: its denominator
        is zero at ``step = a - 1`` and negative beyond. ``step`` is therefore
        clamped to at most ``0.9 * (a - 1)``. For larger requested steps the
        result is the prox for the clamped step, not the requested one. Note
        that the threshold applied to ``w`` scales with the (clamped) step.

        Parameters
        ----------
        w : array
            Input array (numpy, cupy or torch, matching ``backend``).
        step : float
            Step size.
        backend : str
            'cupy', 'torch'; any other value uses the NumPy path.

        Returns
        -------
        array
            NumPy path: a float64 array. CuPy path: output of a float64
            ElementwiseKernel. Torch path: a tensor computed from ``w``.
        """
        alpha = self.alpha
        a = self.a
        # Clamp step: ensure a > 1 + step (three-region condition).
        # Use 0.9*(a-1) as max to avoid the singularity at step = a-1.
        max_step = 0.9 * (a - 1.0)
        if step > max_step:
            step = max_step
        t = alpha * step

        if backend == "cupy":
            import cupy as cp
            # NOTE(review): the kernel declares float64 for ``w``; float32 CuPy
            # inputs may be rejected by CuPy's type check — confirm if needed.
            if SCADPenalty._SCAD_PROXIMAL_CUPY is None:
                SCADPenalty._SCAD_PROXIMAL_CUPY = cp.ElementwiseKernel(
                    'float64 w, float64 step, float64 alpha, float64 a',
                    'float64 result',
                    '''
                    double max_step = 0.9 * (a - 1.0);
                    double s = (step > max_step) ? max_step : step;
                    double abs_w = abs(w);
                    double t = alpha * s;
                    double sign_w = (w > 0.0) ? 1.0 : ((w < 0.0) ? -1.0 : 0.0);
                    if (abs_w <= alpha + t) {
                        double v = abs_w - t;
                        result = sign_w * (v > 0.0 ? v : 0.0);
                    } else if (abs_w <= a * alpha) {
                        result = sign_w * ((a - 1.0) * abs_w - a * t) / (a - 1.0 - s);
                    } else {
                        result = w;
                    }
                    ''',
                    'scad_proximal',
                )
            return SCADPenalty._SCAD_PROXIMAL_CUPY(w, step, alpha, a)

        elif backend == "torch":
            import torch
            compiled_fn = _get_scad_torch_compiled()
            if compiled_fn is not None:
                # Every scalar goes in as a 0-d tensor on w's device/dtype.
                # torch.compile may guard on (specialize to) the values of Python
                # float arguments, so plain floats for alpha/a could trigger a
                # recompile for each distinct penalty strength (e.g. along a path).
                def _as_t(x):
                    return torch.as_tensor(x, dtype=w.dtype, device=w.device)

                return compiled_fn(w, _as_t(step), _as_t(alpha), _as_t(a))
            abs_w = torch.abs(w)
            sign_w = torch.sign(w)

            r1 = abs_w <= alpha + t
            r3 = abs_w > a * alpha
            r2 = ~(r1 | r3)
            result = torch.where(r1,
                                 sign_w * torch.relu(abs_w - t),
                                 torch.where(r2,
                                             sign_w * ((a - 1.0) * abs_w - a * t) / (a - 1.0 - step),
                                             w))
            return result

        else:
            abs_w = np.abs(w)
            sign_w = np.sign(w)

            region1 = abs_w <= alpha + t
            region3 = abs_w > a * alpha
            region2 = ~(region1 | region3)

            result = np.zeros_like(w, dtype=float)
            result[region1] = sign_w[region1] * np.maximum(abs_w[region1] - t, 0.0)
            result[region2] = (
                sign_w[region2]
                * ((a - 1.0) * abs_w[region2] - a * t)
                / (a - 1.0 - step)
            )
            result[region3] = w[region3]
            return result

    # ----------------------------------------------------------------
    # LLA weights (Local Linear Approximation path)
    # ----------------------------------------------------------------

    def lla_weights(self, coef):
        """
        LLA weights: w_j = P'(|coef_j|) — the subgradient of SCAD at |coef_j|.

            w_j = {
                alpha                           if |coef_j| <= alpha
                (a*alpha - |coef_j|) / (a - 1)  if alpha < |coef_j| <= a*alpha
                0                               if |coef_j| > a*alpha
            }

        Accepts numpy, cupy, or torch arrays. Returns same backend type.
        """
        a = self.a
        alpha = self.alpha

        xp = _xp(coef)
        abs_w = xp.abs(coef)
        weights = xp.full_like(coef, alpha)
        mask2 = (abs_w > alpha) & (abs_w <= a * alpha)
        weights[mask2] = (a * alpha - abs_w[mask2]) / (a - 1.0)
        mask3 = abs_w > a * alpha
        weights[mask3] = 0.0
        return weights

    # ----------------------------------------------------------------

    def get_params(self) -> dict:
        """Return the base-class parameters plus ``alpha`` and ``a``."""
        params = super().get_params()
        params.update({"alpha": self.alpha, "a": self.a})
        return params
|