statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,553 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Group MCP penalty.
|
|
3
|
+
|
|
4
|
+
Breheny & Huang 2009 (grpreg). Non-convex group penalty: applies MCP
|
|
5
|
+
concavity to the L2 norm of each feature group.
|
|
6
|
+
|
|
7
|
+
Penalty:
|
|
8
|
+
P(w) = sum_g MCP(||w_g||_2; alpha * sqrt(p_g), gamma)
|
|
9
|
+
|
|
10
|
+
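where MCP(t; lambda, gamma) is the scalar MCP penalty evaluated at the group norm t >= 0:
MCP(t; lambda, gamma) = lambda * t - t**2 / (2 * gamma) for t <= gamma * lambda, and gamma * lambda**2 / 2 otherwise.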
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
# Public API of this module: only the penalty class is exported by ``import *``.
__all__ = ["GroupMCPPenalty"]
|
|
14
|
+
|
|
15
|
+
from typing import Optional, List, Union
|
|
16
|
+
import numpy as np
|
|
17
|
+
from statgpu.penalties._base import Penalty
|
|
18
|
+
from statgpu.penalties._group_lasso import _vector_norm, _to_backend_array, _backend_zeros, _batched_group_norms, _get_xp
|
|
19
|
+
|
|
20
|
+
# ---- torch.compile lazy-loader for vectorized MCP proximal ---------
# Compiled proximal kernel, or None while it has not been (successfully) built.
_GROUP_MCP_PROXIMAL_TORCH_COMPILED = None
# Set once building the kernel raised, so a failing ``import torch`` /
# ``torch.compile`` call is not retried (and re-raised internally) on every
# proximal evaluation.
_GROUP_MCP_PROXIMAL_TORCH_BUILD_FAILED = False


def _get_group_mcp_torch_compiled():
    """Return a ``torch.compile``-d vectorized group MCP proximal, or None.

    The returned callable has signature
    ``(w_mat, sqrt_pg, alpha, step, gamma) -> tensor``: ``w_mat`` holds one
    group per row, and the result is the row-wise MCP-thresholded matrix
    flattened to 1D.

    Returns None when ``statgpu.penalties._torch_compile_ok()`` is falsy or
    when building the compiled wrapper raised. A successfully built kernel is
    cached for the rest of the process. Until then, ``_torch_compile_ok()`` is
    re-evaluated on every call. A build failure is remembered, so it is not
    retried.

    ``torch.compile`` compiles lazily: compiler/backend errors surface on the
    first *call* of the returned function, not inside this loader. Callers
    should be prepared to fall back to an eager implementation.
    """
    global _GROUP_MCP_PROXIMAL_TORCH_COMPILED, _GROUP_MCP_PROXIMAL_TORCH_BUILD_FAILED
    if _GROUP_MCP_PROXIMAL_TORCH_COMPILED is not None:
        return _GROUP_MCP_PROXIMAL_TORCH_COMPILED
    if _GROUP_MCP_PROXIMAL_TORCH_BUILD_FAILED:
        return None

    from statgpu.penalties import _torch_compile_ok
    if not _torch_compile_ok():
        return None

    try:
        import torch

        def _prox(w_mat, sqrt_pg, alpha, step, gamma):
            # Per-group thresholds: zero-out below t_g, identity above gamma*lambda_g.
            t_g = alpha * sqrt_pg * step
            gamma_alpha_g = gamma * alpha * sqrt_pg
            norms = torch.linalg.norm(w_mat, dim=1)
            mask_zero = norms <= t_g
            mask_shrink = (norms > t_g) & (norms <= gamma_alpha_g)
            # Neutral denominator outside the shrink region avoids 0/0.
            denom = norms * (1.0 - step / gamma)
            denom = torch.where(mask_shrink, denom, torch.ones_like(denom))
            scale_shrink = (norms - t_g) / denom
            scale = torch.where(mask_shrink, scale_shrink, 1.0)
            scale = torch.where(mask_zero, 0.0, scale)
            return (w_mat * scale[:, None]).reshape(-1)

        _GROUP_MCP_PROXIMAL_TORCH_COMPILED = torch.compile(
            _prox, dynamic=True, mode='reduce-overhead'
        )
    except Exception:
        # Best effort: remember the failure so the eager path is used from now on.
        _GROUP_MCP_PROXIMAL_TORCH_BUILD_FAILED = True
        _GROUP_MCP_PROXIMAL_TORCH_COMPILED = None
    return _GROUP_MCP_PROXIMAL_TORCH_COMPILED
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class GroupMCPPenalty(Penalty):
    """Group MCP penalty.

    Applies the minimax concave penalty (MCP) to the Euclidean norm of each
    feature group:

        P(w) = sum_g MCP(||w_g||_2; alpha * sqrt(p_g), gamma)

    where ``p_g`` is the size of group ``g`` and, for ``t >= 0``,

        MCP(t; lam, gamma) = lam * t - t**2 / (2 * gamma)   if t <= gamma * lam
                           = gamma * lam**2 / 2              otherwise.

    Parameters
    ----------
    alpha : float, default=1.0
        Regularization strength. Must be finite and strictly positive.
    gamma : float, default=3.0
        MCP concavity parameter. Must be finite and > 1. Larger gamma makes
        the penalty closer to group lasso: the penalty keeps growing up to a
        larger norm, so large groups are shrunk more, but the problem is less
        non-convex. A gamma close to 1 approaches group hard thresholding,
        with less bias and stronger non-convexity.
    groups : list of lists, or 1D array-like, optional
        Group membership specification. Either a list of per-group feature
        index lists, or a 1D list/tuple/ndarray of non-negative integer group
        ids (one entry per feature). Together the groups must cover the
        feature indices ``0 .. p-1`` exactly once. When omitted, the penalty
        methods raise ``ValueError`` until groups are supplied through
        ``_init_groups``.

    Notes
    -----
    Group MCP is **non-convex** (``is_convex=False``) and is optimized via LLA
    (Local Linear Approximation). The objective may have multiple local
    minima. Different solvers or initializations can converge to different
    local minima with comparable objective values. A coefficient
    ``max|diff|`` of up to ~1e-2 across runs is expected and does not
    indicate a bug.

    Coefficient vectors may be longer than the number of grouped features
    (e.g. an appended intercept). Only the leading ``p`` entries are
    penalized; the trailing entries get a zero gradient and pass through the
    proximal operator unchanged.
    """

    name = "group_mcp"
    is_convex = False
    supports_group = True

    def __init__(
        self,
        alpha: float = 1.0,
        gamma: float = 3.0,
        groups=None,
    ):
        if not np.isfinite(alpha) or alpha <= 0.0:
            raise ValueError("alpha must be a finite positive scalar for group MCP penalty")
        if not np.isfinite(gamma) or gamma <= 1.0:
            raise ValueError("gamma must be a finite scalar greater than 1 for group MCP penalty")
        self.alpha = alpha
        self.gamma = gamma
        self._group_indices = None
        self._group_sizes = None
        self._sqrt_pg = None
        self._n_groups = 0
        self._all_equal_size = False
        self._is_contiguous = False
        self._group_size_uniform = None
        self._flat_indices = None
        self._group_feat_idx = None
        self._padded_row_idx = None
        self._padded_col_idx = None
        # Backend/device copies of the NumPy index/scale arrays above, keyed by
        # (attribute, backend, device, dtype). Cleared whenever groups change.
        self._device_cache = {}

        if groups is not None:
            self._init_groups(groups)

    # ----------------------------------------------------------------
    # Group specification
    # ----------------------------------------------------------------

    @staticmethod
    def _groups_from_ids(groups):
        """Convert a 1D sequence of group ids into per-group index arrays.

        Group ``g`` receives the (ascending) positions ``i`` with
        ``groups[i] == g`` for ``g`` in ``0 .. max(groups)``. Ids that never
        occur produce empty groups.
        """
        group_ids = np.asarray(groups, dtype=int)
        if group_ids.ndim != 1:
            raise ValueError("group ids must be a 1D sequence")
        if group_ids.size == 0:
            raise ValueError("groups must not be empty")
        if group_ids.min() < 0:
            raise ValueError("group ids must be non-negative integers")
        n_groups = int(group_ids.max() + 1)
        # A stable argsort keeps each group's indices in ascending order. This
        # matches ``np.where(group_ids == g)[0]`` without an O(G * p) scan.
        order = np.argsort(group_ids, kind="stable")
        counts = np.bincount(group_ids, minlength=n_groups)
        return np.split(order, np.cumsum(counts)[:-1])

    def _init_groups(self, groups):
        """Parse and validate a group specification, then build index tables.

        Raises
        ------
        TypeError
            If ``groups`` is not a list, tuple or 1D ndarray.
        ValueError
            If the groups are empty, contain negative or duplicate feature
            indices, or do not cover a dense range ``0 .. max_index``.

        The instance is only modified after the specification has been fully
        validated, so a failed call leaves any previous groups intact.
        """
        if isinstance(groups, np.ndarray) and groups.ndim == 1:
            group_indices = self._groups_from_ids(groups)
        elif isinstance(groups, (list, tuple)):
            if len(groups) == 0:
                raise ValueError("groups must not be empty")
            if isinstance(groups[0], (list, tuple, np.ndarray)):
                group_indices = [np.asarray(g, dtype=int) for g in groups]
            else:
                group_indices = self._groups_from_ids(groups)
        else:
            raise TypeError(
                f"groups must be list or array, got {type(groups).__name__}"
            )

        sizes = np.array([len(g) for g in group_indices], dtype=int)
        n_groups = len(group_indices)

        # Validate that the groups form a partition of 0..p-1.
        flat_indices = np.concatenate(
            [np.asarray(g, dtype=np.int64) for g in group_indices]
        )
        if flat_indices.size == 0:
            raise ValueError("groups must contain at least one feature index")
        if flat_indices.min() < 0:
            # Without this check e.g. [[-1, 1]] passes the dense test below
            # and leaves feature 0 unassigned.
            raise ValueError("groups must not contain negative feature indices")
        n_features = int(flat_indices.max()) + 1
        unique_idx = np.unique(flat_indices)
        if unique_idx.size != flat_indices.size:
            raise ValueError("groups contain duplicate feature indices")
        if unique_idx.size != n_features:
            raise ValueError(
                "groups must cover a dense range of feature indices [0..max_index]"
            )

        all_equal = np.unique(sizes).size == 1
        # Contiguous <=> concatenating the groups in order yields 0..p-1.
        is_contiguous = bool(np.array_equal(flat_indices, np.arange(n_features)))
        # Row (group) index of every position in the group-ordered layout.
        group_of_pos = np.repeat(np.arange(n_groups), sizes).astype(np.int64)
        # Feature -> group mapping, used to broadcast per-group values.
        group_feat_idx = np.empty(n_features, dtype=np.int64)
        group_feat_idx[flat_indices] = group_of_pos

        self._group_indices = group_indices
        self._group_sizes = sizes
        self._sqrt_pg = np.sqrt(sizes.astype(float))
        self._n_groups = n_groups
        self._all_equal_size = all_equal
        self._group_size_uniform = int(sizes[0]) if all_equal else None
        self._is_contiguous = is_contiguous
        self._flat_indices = None if is_contiguous else flat_indices
        if all_equal:
            self._padded_row_idx = None
            self._padded_col_idx = None
        else:
            # Gather/scatter coordinates for the zero-padded (G, max_size) layout.
            self._padded_row_idx = group_of_pos
            self._padded_col_idx = np.concatenate(
                [np.arange(sz) for sz in sizes]
            ).astype(np.int64)
        self._group_feat_idx = group_feat_idx
        # Every cached backend copy refers to the old groups: drop them all.
        self._device_cache = {}

    def _require_groups(self, method):
        """Raise ValueError if no group specification has been set."""
        if self._group_indices is None:
            raise ValueError(f"groups must be set before calling {method}()")

    # ----------------------------------------------------------------
    # Backend array caches
    # ----------------------------------------------------------------

    def _get_cached(self, attr_name, xp, w):
        """Return ``getattr(self, attr_name)`` converted for backend ``xp``.

        The conversion goes through ``_to_backend_array(..., xp, w)``. It is
        memoized per (attribute, backend module name, ``w.device``,
        ``w.dtype``). This means NumPy, CuPy and Torch callers, and tensors on
        different devices, never receive each other's arrays.
        """
        cache = getattr(self, "_device_cache", None)
        if cache is None:
            cache = self._device_cache = {}
        key = (
            attr_name,
            xp.__name__,
            str(getattr(w, "device", None)),
            str(getattr(w, "dtype", None)),
        )
        cached = cache.get(key)
        if cached is None:
            cached = _to_backend_array(getattr(self, attr_name), xp, w)
            cache[key] = cached
        return cached

    def _get_sqrt_pg(self, xp, w):
        """Cached backend copy of the per-group ``sqrt(p_g)`` factors."""
        return self._get_cached('_sqrt_pg', xp, w)

    def _get_flat_indices(self, xp, w):
        """Cached backend copy of ``_flat_indices`` (None for contiguous groups)."""
        if getattr(self, '_flat_indices', None) is None:
            return None
        return self._get_cached('_flat_indices', xp, w)

    # ----------------------------------------------------------------
    # Layout helpers
    # ----------------------------------------------------------------

    def _pad_groups(self, w_feat, xp, w_ref, max_sz=None):
        """Scatter grouped features into a zero-padded ``(G, max_sz)`` matrix.

        Returns ``(padded, row_idx, col_idx)``. The backend index arrays are
        returned so callers can gather the same positions back out.
        """
        if max_sz is None:
            max_sz = int(self._group_sizes.max())
        padded = _backend_zeros(
            (self._n_groups, max_sz), xp, dtype=w_feat.dtype, ref_arr=w_ref
        )
        row_idx = self._get_cached('_padded_row_idx', xp, w_ref)
        col_idx = self._get_cached('_padded_col_idx', xp, w_ref)
        if self._is_contiguous:
            padded[row_idx, col_idx] = w_feat
        else:
            padded[row_idx, col_idx] = w_feat[self._get_flat_indices(xp, w_ref)]
        return padded, row_idx, col_idx

    def _batched_group_norms_vec(self, coef_feat, xp, w_ref):
        """Per-group L2 norms for unequal group sizes, via the padded layout."""
        padded, _, _ = self._pad_groups(coef_feat, xp, w_ref)
        return _vector_norm(padded, xp, dim=1)

    def _reshape_to_matrix(self, w, xp, G, gs):
        """View the first ``G * gs`` entries of ``w`` as a ``(G, gs)`` group matrix."""
        w_feat = w[:G * gs]  # handle augmented intercept
        if self._is_contiguous:
            return w_feat.reshape(G, gs)
        # Gather with the backend-resident index, avoiding a host->device copy.
        return w_feat[self._get_flat_indices(xp, w)].reshape(G, gs)

    def _group_norms(self, coef, xp):
        """Return the L2 norm of every group of ``coef`` (shape ``(G,)``)."""
        if self._all_equal_size and self._group_size_uniform is not None:
            w_mat = self._reshape_to_matrix(
                coef, xp, self._n_groups, self._group_size_uniform
            )
            return _vector_norm(w_mat, xp, dim=1)
        p_total = int(self._group_sizes.sum())
        return self._batched_group_norms_vec(coef[:p_total], xp, coef)

    def _scatter_from_flat(self, flat_vals, result, xp):
        """Write group-ordered ``flat_vals`` back into ``result`` in place."""
        p_total = len(flat_vals)
        if self._is_contiguous:
            result[:p_total] = flat_vals
        else:
            flat_idx = self._get_flat_indices(xp, result)
            result[flat_idx] = flat_vals

    # ----------------------------------------------------------------
    # Value
    # ----------------------------------------------------------------

    def value(self, coef) -> float:
        """Return the penalty value ``sum_g MCP(||w_g||; alpha*sqrt(p_g), gamma)``."""
        self._require_groups("value")
        xp = _get_xp(coef)

        norms = self._group_norms(coef, xp)
        alpha_g = self.alpha * self._get_sqrt_pg(xp, coef)
        inside = norms <= self.gamma * alpha_g
        concave_part = alpha_g * norms - norms ** 2 / (2.0 * self.gamma)
        flat_part = 0.5 * self.gamma * alpha_g ** 2
        # xp.where keeps everything on-device (no data-dependent boolean gathers).
        return float(xp.sum(xp.where(inside, concave_part, flat_part)))

    # ----------------------------------------------------------------
    # Gradient
    # ----------------------------------------------------------------

    def gradient(self, coef) -> np.ndarray:
        """Return the penalty gradient, same shape and backend as ``coef``.

        For a group with ``0 < ||w_g|| <= gamma * alpha_g`` the gradient is
        ``(alpha_g - ||w_g|| / gamma) * w_g / ||w_g||``. Groups at zero or
        beyond the flat region contribute 0, as do trailing (intercept)
        entries.
        """
        self._require_groups("gradient")
        xp = _get_xp(coef)

        p_total = int(self._group_sizes.sum())
        coef_feat = coef[:p_total]  # handle augmented intercept

        norms = self._group_norms(coef, xp)
        alpha_g = self.alpha * self._get_sqrt_pg(xp, coef)
        gamma_alpha_g = self.gamma * alpha_g

        # One fused scale per group: derivative / norm.
        mask_active = (norms > 0) & (norms <= gamma_alpha_g)
        if xp.__name__ == "torch":
            safe_norms = xp.clamp(norms, min=1e-15)
        else:
            safe_norms = xp.maximum(norms, 1e-15)
        scale_g = xp.where(mask_active,
                           (alpha_g - norms / self.gamma) / safe_norms,
                           0.0)

        feat_idx = self._get_cached('_group_feat_idx', xp, coef)
        grad = xp.zeros_like(coef)
        grad[:p_total] = scale_g[feat_idx] * coef_feat
        return grad

    # ----------------------------------------------------------------
    # Proximal operator (group MCP)
    # ----------------------------------------------------------------

    def proximal(self, w, step: float, backend: str = "numpy"):
        """Per-group MCP proximal (firm thresholding of each group norm).

        ``backend`` selects the implementation: "cupy" and "torch" use a
        vectorized path; anything else uses a per-group NumPy loop. ``step``
        is clamped to at most ``0.9 * gamma`` so that the shrinkage
        denominator ``1 - step/gamma`` stays positive.
        """
        self._require_groups("proximal")

        if backend == "cupy":
            import cupy as cp
            return self._proximal_vectorized(w, step, cp)
        elif backend == "torch":
            import torch
            return self._proximal_vectorized(w, step, torch)
        else:
            return self._proximal_loop(w, step, np)

    def _proximal_loop(self, w, step, xp):
        """Reference per-group proximal loop (used for the NumPy backend)."""
        step = min(float(step), 0.9 * self.gamma)  # defense-in-depth clamping
        result = w.copy() if hasattr(w, 'copy') else w.clone()
        for g, idx in enumerate(self._group_indices):
            w_g = w[idx]
            ng = float(xp.linalg.norm(w_g))
            t_g = self.alpha * self._sqrt_pg[g] * step
            gamma_alpha_g = self.gamma * self.alpha * self._sqrt_pg[g]

            if ng <= t_g:
                result[idx] = 0.0
            elif t_g < ng <= gamma_alpha_g:
                scale = (ng - t_g) / (ng * (1.0 - step / self.gamma))
                result[idx] = w_g * scale
            else:
                result[idx] = w_g
        return result

    def _proximal_vectorized(self, w, step, xp):
        """Vectorized group MCP proximal (equal-size or padded layout)."""
        G = self._n_groups
        if self._all_equal_size and self._group_size_uniform is not None:
            return self._proximal_equal(w, step, xp, G, self._group_size_uniform)
        max_sz = int(self._group_sizes.max())
        return self._proximal_padded(w, step, xp, G, max_sz)

    def _mcp_group_scale(self, norms, sqrt_pg_arr, step, xp):
        """Per-group multiplicative factor of the MCP proximal map.

        Regions: ``norm <= t_g`` -> 0; ``t_g < norm <= gamma*alpha_g`` ->
        ``(norm - t_g) / (norm * (1 - step/gamma))``; otherwise -> 1.
        """
        t_g = self.alpha * sqrt_pg_arr * step
        gamma_alpha_g = self.gamma * self.alpha * sqrt_pg_arr
        mask_zero = norms <= t_g
        mask_shrink = (norms > t_g) & (norms <= gamma_alpha_g)
        # Neutral denominator outside the shrink region avoids 0/0.
        denom = norms * (1.0 - step / self.gamma)
        denom = xp.where(mask_shrink, denom, xp.ones_like(denom))
        scale = xp.where(mask_shrink, (norms - t_g) / denom, 1.0)
        return xp.where(mask_zero, 0.0, scale)

    def _proximal_equal(self, w, step, xp, G, gs):
        """Equal-size groups: vectorized MCP proximal on a ``(G, gs)`` matrix."""
        # Clamp step to prevent division by zero in denom = norms*(1 - step/gamma).
        step = min(float(step), 0.9 * self.gamma)
        w_mat = self._reshape_to_matrix(w, xp, G, gs)
        sqrt_pg_arr = self._get_sqrt_pg(xp, w)

        # Torch compiled fast path (best effort).
        if xp.__name__ == "torch" and not getattr(self, "_torch_compile_failed", False):
            compiled_fn = _get_group_mcp_torch_compiled()
            if compiled_fn is not None:
                try:
                    scaled_flat = compiled_fn(w_mat, sqrt_pg_arr, self.alpha, step, self.gamma)
                except Exception as exc:
                    # torch.compile builds lazily, so backend failures appear
                    # here. Fall back to the eager path for this instance.
                    import warnings
                    self._torch_compile_failed = True
                    warnings.warn(
                        "torch.compile group MCP proximal failed "
                        f"({exc!r}); falling back to the eager implementation.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                else:
                    result = w.clone()
                    self._scatter_from_flat(scaled_flat, result, xp)
                    return result

        # Generic vectorized path.
        norms = _vector_norm(w_mat, xp, dim=1)
        scale = self._mcp_group_scale(norms, sqrt_pg_arr, step, xp)  # (G,)
        scaled_flat = (w_mat * scale[:, None]).reshape(-1)
        result = w.copy() if hasattr(w, 'copy') else w.clone()
        self._scatter_from_flat(scaled_flat, result, xp)
        return result

    def _proximal_padded(self, w, step, xp, G, max_sz):
        """Unequal groups: pad to ``(G, max_sz)``, threshold, scatter back."""
        step = min(float(step), 0.9 * self.gamma)
        p_total = int(self._group_sizes.sum())
        w_feat = w[:p_total]  # handle augmented intercept

        padded, row_idx, col_idx = self._pad_groups(w_feat, xp, w, max_sz)
        # Zero padding does not change the L2 norm of a group.
        norms = _vector_norm(padded, xp, dim=1)
        scale = self._mcp_group_scale(norms, self._get_sqrt_pg(xp, w), step, xp)
        padded_scaled = padded * scale[:, None]

        result = w.copy() if hasattr(w, 'copy') else w.clone()
        self._scatter_from_flat(padded_scaled[row_idx, col_idx], result, xp)
        return result

    # ----------------------------------------------------------------
    # LLA weights (for LLA outer loop optimization)
    # ----------------------------------------------------------------

    def lla_weights(self, coef):
        """Return per-feature LLA weights ``MCP'(||w_g||)`` for every grouped feature.

        The group weight is ``max(alpha_g - ||w_g|| / gamma, 0)`` inside the
        concave region and 0 beyond ``gamma * alpha_g``. Each feature receives
        the weight of the group it belongs to, so the output has one entry per
        grouped feature (trailing intercept entries are not included).
        """
        self._require_groups("lla_weights")
        xp = _get_xp(coef)

        norms = self._group_norms(coef, xp)
        alpha_g = self.alpha * self._get_sqrt_pg(xp, coef)
        slope = alpha_g - norms / self.gamma
        inside = norms <= self.gamma * alpha_g
        if xp.__name__ == "torch":
            weight_g = xp.where(inside, xp.clamp(slope, min=0.0), xp.zeros_like(norms))
        else:
            weight_g = xp.where(inside, xp.maximum(slope, 0.0), 0.0)

        # Gather through the feature->group map. Unlike repeat(), this is
        # correct for interleaved (non-contiguous) groups of equal size too.
        feat_idx = self._get_cached('_group_feat_idx', xp, coef)
        return weight_g[feat_idx]

    # ----------------------------------------------------------------

    def get_params(self) -> dict:
        """Return base-class params plus ``alpha``, ``gamma`` and ``n_groups``."""
        params = super().get_params()
        params.update({
            "alpha": self.alpha,
            "gamma": self.gamma,
            "n_groups": self._n_groups,
        })
        return params
|