statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,678 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Group Lasso penalty.
|
|
3
|
+
|
|
4
|
+
Yuan & Lin, JRSSB 2006. Convex penalty that selects groups of features.
|
|
5
|
+
|
|
6
|
+
The penalty is:
|
|
7
|
+
P(w) = alpha * sum_g sqrt(p_g) * ||w_g||_2
|
|
8
|
+
|
|
9
|
+
where w_g is the subvector of w for group g, and p_g is the size of group g.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
__all__ = ["GroupLassoPenalty", "AdaptiveGroupLassoPenalty"]
|
|
13
|
+
|
|
14
|
+
from typing import Optional, List, Union
|
|
15
|
+
import numpy as np
|
|
16
|
+
from statgpu.penalties._base import Penalty
|
|
17
|
+
|
|
18
|
+
# ---- torch.compile lazy-loader for vectorized proximal on GPU ---------
|
|
19
|
+
_GROUP_LASSO_PROXIMAL_TORCH_COMPILED_EQUAL = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _get_group_lasso_torch_compiled_equal():
    """Return the lazily built ``torch.compile`` kernel for equal-size groups.

    The kernel maps a ``(G, gs)`` weight matrix to the flattened result of
    block soft-thresholding every row.  A kernel that was built successfully
    is kept in a module global and handed back on later calls.

    Returns
    -------
    callable or None
        The compiled kernel, or ``None`` when ``_torch_compile_ok()`` says
        compilation is unavailable or when importing torch / wrapping the
        kernel raises.
    """
    global _GROUP_LASSO_PROXIMAL_TORCH_COMPILED_EQUAL

    kernel = _GROUP_LASSO_PROXIMAL_TORCH_COMPILED_EQUAL
    if kernel is not None:
        return kernel

    from statgpu.penalties import _torch_compile_ok

    if not _torch_compile_ok():
        _GROUP_LASSO_PROXIMAL_TORCH_COMPILED_EQUAL = None
        return None

    try:
        import torch

        def _prox(w_mat, sqrt_pg, alpha, step):
            # Row-wise shrink factor max(0, 1 - thresh / ||row||); the 1e-12
            # keeps the division finite for all-zero rows.
            thresh = alpha * sqrt_pg * step
            norms = torch.linalg.norm(w_mat, dim=1)
            scale = torch.clamp(1.0 - thresh / (norms + 1e-12), min=0.0)
            return (w_mat * scale[:, None]).reshape(-1)

        kernel = torch.compile(_prox, dynamic=True, mode='reduce-overhead')
    except Exception:
        # Best effort: any failure simply disables the compiled fast path.
        kernel = None

    _GROUP_LASSO_PROXIMAL_TORCH_COMPILED_EQUAL = kernel
    return kernel
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _vector_norm(x, xp, dim=None):
    """Return the L2 norm of ``x`` using the array module ``xp``.

    With ``dim=None`` the norm of the whole array is returned.  Otherwise the
    reduction runs along ``dim``; torch receives it as ``dim=`` and every
    other backend as ``axis=``.
    """
    if dim is None:
        return xp.linalg.norm(x)
    if xp.__name__ == "torch":
        return xp.linalg.norm(x, dim=dim)
    return xp.linalg.norm(x, axis=dim)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _to_backend_array(arr, xp, ref_arr=None):
    """Convert host data ``arr`` into an array of the backend ``xp``.

    Parameters
    ----------
    arr : array-like
        Data to convert.
    xp : module
        Target array module; dispatch is on ``xp.__name__``.
    ref_arr : optional
        Torch only: when given, the result is moved to ``ref_arr.device``.

    Returns
    -------
    For torch, a tensor built with ``torch.from_numpy``: signed/unsigned
    integer input keeps its dtype, any other dtype is cast to float64.
    For every other backend, ``xp.asarray(arr)``.
    """
    if xp.__name__ != "torch":
        return xp.asarray(arr)

    import torch

    host = np.asarray(arr)
    # Integer arrays are kept as integers (they are used for indexing);
    # everything else is promoted to float64.
    if host.dtype.kind not in ('i', 'u'):
        host = host.astype(np.float64)
    tensor = torch.from_numpy(host)
    if ref_arr is None:
        return tensor
    return tensor.to(device=ref_arr.device)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _backend_zeros(shape, xp, dtype=None, ref_arr=None):
    """Create a zero-filled array of ``shape`` on the backend ``xp``.

    Parameters
    ----------
    shape : int or tuple of int
        Shape of the result.
    xp : module
        Target array module; dispatch is on ``xp.__name__``.
    dtype : optional
        Element type.  For torch, ``None`` means ``torch.float64``; for other
        backends ``None`` is forwarded to ``xp.zeros`` unchanged.
    ref_arr : optional
        Torch only: the result is created on ``ref_arr.device``.

    Returns
    -------
    A torch tensor or an ``xp`` array filled with zeros.
    """
    if xp.__name__ == "torch":
        import torch

        # Allocate directly on the target device.  Creating the tensor on the
        # host and calling ``.to(device)`` afterwards costs an extra host
        # allocation plus a host-to-device copy on every call.
        return torch.zeros(
            shape,
            dtype=torch.float64 if dtype is None else dtype,
            device=None if ref_arr is None else ref_arr.device,
        )
    return xp.zeros(shape, dtype=dtype)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _batched_group_norms(coef, group_indices, xp):
    """Return the L2 norm of ``coef`` restricted to each index set.

    Parameters
    ----------
    coef : array
        Coefficient vector on the backend ``xp``.
    group_indices : iterable of index sequences
        One index collection per group; an empty collection yields a zero.
    xp : module
        Array module; dispatch is on ``xp.__name__``.

    Returns
    -------
    One norm per group: ``xp.stack`` of scalars for torch, ``xp.array`` for
    cupy, and a NumPy array otherwise.
    """
    backend = xp.__name__

    def _zero():
        # Zero in the same container type the non-empty groups produce.
        if backend == "torch":
            return xp.zeros(1, device=coef.device, dtype=coef.dtype)[0]
        if backend == "cupy":
            return xp.zeros(1, dtype=coef.dtype)[0]
        return 0.0

    per_group = [
        xp.linalg.norm(coef[idx]) if len(idx) > 0 else _zero()
        for idx in group_indices
    ]

    if backend == "torch":
        return xp.stack(per_group)
    if backend == "cupy":
        return xp.array(per_group)
    return np.array(per_group)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
# Use canonical _xp from backends (replaces local _get_xp)
|
|
101
|
+
from statgpu.backends._array_ops import _xp as _get_xp
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class GroupLassoPenalty(Penalty):
|
|
105
|
+
"""Group Lasso penalty.
|
|
106
|
+
|
|
107
|
+
Parameters
|
|
108
|
+
----------
|
|
109
|
+
alpha : float, default=1.0
|
|
110
|
+
Regularization strength.
|
|
111
|
+
groups : list of lists, or 1D array-like
|
|
112
|
+
Group membership specification. Two forms accepted:
|
|
113
|
+
- List of lists of feature indices, e.g. [[0,1], [2,3,4]]
|
|
114
|
+
- 1D array of length n_features where each entry is the group ID
|
|
115
|
+
"""
|
|
116
|
+
|
|
117
|
+
name = "group_lasso"
|
|
118
|
+
is_convex = True
|
|
119
|
+
supports_group = True
|
|
120
|
+
|
|
121
|
+
def __init__(
    self,
    alpha: float = 1.0,
    groups=None,
):
    """Store ``alpha`` and, when ``groups`` is given, parse the group layout.

    Without ``groups`` the layout attributes stay at their empty placeholders
    (``None`` / ``False``) and ``_init_groups`` is not called.
    """
    self.alpha = alpha

    # Placeholders describing "no groups configured yet".
    self._group_indices = self._group_sizes = None
    self._all_equal_size = self._is_contiguous = False
    self._group_size_uniform = self._flat_indices = None

    if groups is None:
        return
    self._init_groups(groups)
|
|
136
|
+
|
|
137
|
+
def _init_groups(self, groups):
    """Parse a group specification and precompute every derived lookup table.

    Parameters
    ----------
    groups : list/tuple of index sequences, or 1-D sequence/array of group IDs
        Either explicit feature indices per group, e.g. ``[[0, 1], [2, 3, 4]]``,
        or a flat vector whose i-th entry is the group ID of feature i.

    Raises
    ------
    TypeError
        If ``groups`` is neither a list/tuple nor a 1-D ``numpy.ndarray``.
    ValueError
        If ``groups`` is empty, names no feature at all, contains a negative
        feature index, or lists the same feature more than once.

    Warns
    -----
    UserWarning
        If a feature below the largest listed index belongs to no group; each
        such feature is appended as its own single-feature group.

    Notes
    -----
    Every size-dependent attribute (``_group_sizes``, ``_sqrt_pg``,
    ``_n_groups``, the equal-size / contiguity flags and the index tables) is
    derived once, *after* the auto-fill step, so they always describe the
    final group list.
    """
    import warnings

    # ---- 1. Normalise the specification into a list of index arrays ----
    if isinstance(groups, np.ndarray) and groups.ndim == 1:
        given_as_ids = True
    elif isinstance(groups, (list, tuple)):
        if len(groups) == 0:
            raise ValueError("groups must not be empty")
        given_as_ids = not isinstance(groups[0], (list, tuple, np.ndarray))
    else:
        raise TypeError(
            f"groups must be list or array, got {type(groups).__name__}"
        )

    if given_as_ids:
        group_ids = np.asarray(groups, dtype=int)
        if group_ids.size == 0:
            raise ValueError("groups must not be empty")
        n_ids = int(group_ids.max() + 1)
        group_indices = [np.where(group_ids == g)[0] for g in range(n_ids)]
    else:
        group_indices = [np.asarray(g, dtype=int) for g in groups]

    # ---- 2. Validate the feature indices -------------------------------
    if group_indices:
        flat = np.concatenate(
            [np.asarray(g, dtype=np.int64) for g in group_indices]
        )
    else:
        flat = np.empty(0, dtype=np.int64)
    if flat.size == 0:
        raise ValueError("groups must contain at least one feature index")
    if int(flat.min()) < 0:
        # A negative index would silently alias a feature counted from the
        # end and corrupt the feature -> group table below.
        raise ValueError("groups contain negative feature indices")
    n_features = int(flat.max()) + 1
    covered = np.unique(flat)
    if covered.size != flat.size:
        raise ValueError("groups contain duplicate feature indices")

    # ---- 3. Auto-fill uncovered features as singleton groups -----------
    if covered.size != n_features:
        missing = np.setdiff1d(np.arange(n_features), covered).tolist()
        warnings.warn(
            f"Groups do not cover features {missing}. "
            f"Auto-adding {len(missing)} single-feature groups.",
            UserWarning, stacklevel=2,
        )
        # Stored as arrays, like every other group.
        group_indices.extend(np.array([i], dtype=int) for i in missing)
        flat = np.concatenate(
            [np.asarray(g, dtype=np.int64) for g in group_indices]
        )

    # ---- 4. Derive all size-dependent attributes from the final list ---
    sizes = np.array([len(g) for g in group_indices], dtype=np.int64)
    n_groups = len(group_indices)

    self._group_indices = group_indices
    self._group_sizes = sizes
    self._n_groups = n_groups
    # One sqrt(p_g) per *final* group, including auto-filled singletons.
    self._sqrt_pg = np.sqrt(sizes.astype(float))

    # Equal-size groups enable the reshape-based fast path.
    self._all_equal_size = bool(np.all(sizes == sizes[0]))
    self._group_size_uniform = int(sizes[0]) if self._all_equal_size else None

    # Groups are contiguous iff their concatenation is 0, 1, ..., p-1.
    self._is_contiguous = bool(np.array_equal(flat, np.arange(flat.size)))
    # The gather/scatter permutation is only needed for non-contiguous layouts.
    self._flat_indices = None if self._is_contiguous else flat

    # (row, col) coordinates of every feature inside a (G, max_size) padded
    # matrix, listed in the same order as ``flat``.
    self._padded_row_idx = np.repeat(np.arange(n_groups), sizes).astype(np.int64)
    self._padded_col_idx = np.concatenate(
        [np.arange(sz) for sz in sizes]
    ).astype(np.int64)

    # feature index -> group index
    feat_to_group = np.empty(n_features, dtype=np.int64)
    feat_to_group[flat] = self._padded_row_idx
    self._group_feat_idx = feat_to_group

    # ---- 5. Drop every cached device copy ------------------------------
    # Both spellings of each cache slot (with and without an extra leading
    # underscore) are cleared so no stale device tensor survives a regroup.
    for base in (
        "_sqrt_pg",
        "_padded_row_idx",
        "_padded_col_idx",
        "_flat_indices",
        "_group_feat_idx",
    ):
        for backend in ("torch", "cupy"):
            setattr(self, f"{base}_{backend}", None)
            setattr(self, f"_{base}_{backend}", None)
|
|
267
|
+
|
|
268
|
+
# ----------------------------------------------------------------
|
|
269
|
+
# Value
|
|
270
|
+
# ----------------------------------------------------------------
|
|
271
|
+
|
|
272
|
+
def value(self, coef) -> float:
    """Evaluate ``alpha * sum_g sqrt(p_g) * ||coef_g||_2`` as a Python float.

    Only the first ``sum(group sizes)`` entries of ``coef`` are treated as
    features; trailing entries (e.g. an augmented intercept) are ignored.

    Raises
    ------
    ValueError
        If no groups have been configured.
    """
    if self._group_indices is None:
        raise ValueError("groups must be set before calling value()")

    xp = _get_xp(coef)
    backend = xp.__name__

    n_feat = int(self._group_sizes.sum())
    feats = coef[:n_feat]  # drop any augmented intercept

    if self._all_equal_size and self._group_size_uniform is not None:
        # Equal-size groups: one reshape, then a single row-wise norm.
        ordered = feats if self._is_contiguous else feats[self._flat_indices]
        blocks = ordered.reshape(self._n_groups, self._group_size_uniform)
        norms = _vector_norm(blocks, xp, dim=1)
    else:
        norms = self._batched_group_norms_vec(feats, xp, coef)

    weighted = self.alpha * self._get_sqrt_pg(xp, coef) * norms

    if backend == "torch":
        return xp.sum(weighted).item()
    if backend == "cupy":
        return float(xp.sum(weighted))
    return float(np.sum(weighted))
|
|
302
|
+
|
|
303
|
+
# ----------------------------------------------------------------
|
|
304
|
+
# Gradient
|
|
305
|
+
# ----------------------------------------------------------------
|
|
306
|
+
|
|
307
|
+
def gradient(self, coef) -> np.ndarray:
    """Return the penalty gradient ``alpha * sqrt(p_g) * w_g / ||w_g||`` per group.

    Groups whose norm does not exceed ``1e-15`` get a zero gradient.  Entries
    of ``coef`` beyond the grouped features (e.g. an augmented intercept) are
    left at zero in the result.

    Raises
    ------
    ValueError
        If no groups have been configured.
    """
    if self._group_indices is None:
        raise ValueError("groups must be set before calling gradient()")

    xp = _get_xp(coef)
    on_torch = xp.__name__ == "torch"
    on_cupy = xp.__name__ == "cupy"

    n_feat = int(self._group_sizes.sum())
    feats = coef[:n_feat]  # drop any augmented intercept

    def _per_group_scale(norms, sqrt_pg):
        # alpha * sqrt(p_g) / ||w_g||, forced to 0 for (near-)zero groups.
        floored = xp.clamp(norms, min=1e-15) if on_torch else xp.maximum(norms, 1e-15)
        return xp.where(norms > 1e-15,
                        self.alpha * sqrt_pg / floored,
                        0.0)

    if not (self._all_equal_size and self._group_size_uniform is not None):
        # Unequal groups: expand the per-group scale to features through the
        # feature -> group table.
        norms = self._batched_group_norms_vec(feats, xp, coef)
        scale_g = _per_group_scale(norms, self._get_sqrt_pg(xp, coef))
        group_of = self._get_cached('_group_feat_idx', xp, coef)
        out = xp.zeros_like(coef)
        out[:n_feat] = scale_g[group_of] * feats
        return out

    # Equal-size groups: operate on a (G, gs) matrix.
    rows, width = self._n_groups, self._group_size_uniform
    ordered = feats if self._is_contiguous else feats[self._flat_indices]
    blocks = ordered.reshape(rows, width)

    norms = _vector_norm(blocks, xp, dim=1)
    scale = _per_group_scale(norms, self._get_sqrt_pg(xp, coef))
    flat_grad = (blocks * scale[:, None]).reshape(-1)

    if on_torch or on_cupy:
        out = xp.zeros_like(coef)
    else:
        out = np.zeros_like(coef, dtype=float)

    if self._is_contiguous:
        out[:n_feat] = flat_grad
    else:
        out[self._flat_indices] = flat_grad
    return out
|
|
360
|
+
|
|
361
|
+
# ----------------------------------------------------------------
|
|
362
|
+
# Proximal operator (block soft-thresholding)
|
|
363
|
+
# ----------------------------------------------------------------
|
|
364
|
+
|
|
365
|
+
def proximal(self, w, step: float, backend: str = "numpy"):
    """Apply group soft-thresholding to ``w`` with step size ``step``.

    ``backend`` selects the implementation: ``"cupy"`` and ``"torch"`` use
    the vectorized routine with the matching array module; any other value
    uses the per-group NumPy loop.

    Raises
    ------
    ValueError
        If no groups have been configured.
    """
    if self._group_indices is None:
        raise ValueError("groups must be set before calling proximal()")

    if backend == "cupy":
        import cupy as cp
        return self._proximal_vectorized(w, step, cp)
    if backend == "torch":
        import torch
        return self._proximal_vectorized(w, step, torch)
    return self._proximal_loop(w, step, np)
|
|
382
|
+
|
|
383
|
+
def _proximal_loop(self, w, step, xp):
    """Block soft-threshold ``w`` one group at a time (NumPy CPU path).

    A group whose norm exceeds ``alpha * sqrt(p_g) * step`` is scaled by
    ``1 - thresh / norm``; every other group is set to zero.  ``w`` itself is
    not modified; a copy is returned.
    """
    out = w.copy() if hasattr(w, 'copy') else w.clone()
    for pos, members in enumerate(self._group_indices):
        block_w = w[members]
        magnitude = float(xp.linalg.norm(block_w))
        thresh = self.alpha * self._sqrt_pg[pos] * step
        if magnitude > thresh:
            out[members] = block_w * (1.0 - thresh / magnitude)
        else:
            out[members] = 0.0
    return out
|
|
395
|
+
|
|
396
|
+
def _proximal_vectorized(self, w, step, xp):
    """Dispatch the vectorized proximal step by group layout.

    Equal-size groups go to ``_proximal_equal`` (reshape into a ``(G, gs)``
    matrix); otherwise groups are padded to the largest group size and handled
    by ``_proximal_padded``.
    """
    n_groups = self._n_groups
    uniform = self._group_size_uniform

    if not self._all_equal_size or uniform is None:
        # Unequal groups: pad every group to the largest size.
        widest = int(self._group_sizes.max())
        return self._proximal_padded(w, step, xp, n_groups, widest)

    return self._proximal_equal(w, step, xp, n_groups, uniform)
|
|
411
|
+
|
|
412
|
+
def _gather(self, w, xp):
    """Return ``w`` arranged as a ``(n_groups, group_size)`` matrix.

    Non-contiguous layouts are first permuted with ``_flat_indices``.

    Raises
    ------
    ValueError
        If the groups do not all have the same size.
    """
    if not self._all_equal_size:
        raise ValueError("_gather requires equal-size groups; use _proximal_padded instead")
    ordered = w if self._is_contiguous else w[self._flat_indices]
    return ordered.reshape(self._n_groups, self._group_size_uniform)
|
|
419
|
+
|
|
420
|
+
def _scatter(self, w_mat_flat, result, xp):
    """Write the flattened group-ordered values back into ``result``.

    Contiguous layouts overwrite ``result`` in order; otherwise the values
    are placed at ``_flat_indices``.  ``result`` is modified in place and
    also returned.
    """
    target = slice(None) if self._is_contiguous else self._flat_indices
    result[target] = w_mat_flat
    return result
|
|
427
|
+
|
|
428
|
+
def _get_sqrt_pg(self, xp, w):
    """Return ``_sqrt_pg`` as an array of backend ``xp``, cached per backend.

    Parameters
    ----------
    xp : module
        Array module; dispatch is on ``xp.__name__``.
    w : array or None
        Reference array; for torch its ``device`` decides where the cached
        tensor lives.

    Notes
    -----
    * torch and cupy each have their own cache slot (``_sqrt_pg_torch`` /
      ``_sqrt_pg_cupy``).
    * Any other backend (NumPy) is served straight from the host array and
      never written into the cupy slot, so a NumPy call cannot leave a host
      array behind for a later cupy call.
    * A cached torch tensor is rebuilt when ``w`` lives on a different
      device than the cached copy.
    """
    backend = xp.__name__

    if backend == "torch":
        cached = getattr(self, "_sqrt_pg_torch", None)
        target = getattr(w, "device", None)
        if cached is None or (target is not None and cached.device != target):
            cached = _to_backend_array(self._sqrt_pg, xp, w)
            self._sqrt_pg_torch = cached
        return cached

    if backend == "cupy":
        cached = getattr(self, "_sqrt_pg_cupy", None)
        if cached is None:
            cached = _to_backend_array(self._sqrt_pg, xp, w)
            self._sqrt_pg_cupy = cached
        return cached

    # Host backends: the conversion is free, so there is nothing to cache.
    return xp.asarray(self._sqrt_pg)
|
|
438
|
+
|
|
439
|
+
def _get_cached(self, attr_name, xp, w):
    """Return the host attribute ``attr_name`` as an array of backend ``xp``.

    Parameters
    ----------
    attr_name : str
        Name of the NumPy attribute on ``self`` (e.g. ``'_padded_row_idx'``).
    xp : module
        Array module; dispatch is on ``xp.__name__``.
    w : array or None
        Reference array; for torch its ``device`` decides where the cached
        tensor lives.

    Notes
    -----
    * Device copies are cached under ``f"{attr_name}_{backend}"`` — for
      ``'_padded_row_idx'`` that is ``_padded_row_idx_torch``, the very slot
      ``_init_groups`` resets to ``None``.  Prefixing another underscore would
      produce a different attribute that is never invalidated.
    * Only torch and cupy are cached.  Other backends (NumPy) are converted
      on the fly so they never occupy the cupy slot.
    * A cached torch tensor is rebuilt when ``w`` lives on a different
      device than the cached copy.
    """
    backend = xp.__name__
    if backend not in ("torch", "cupy"):
        return xp.asarray(getattr(self, attr_name))

    cache_attr = f"{attr_name}_{backend}"
    cached = getattr(self, cache_attr, None)

    if cached is not None and backend == "torch":
        target = getattr(w, "device", None)
        if target is not None and cached.device != target:
            cached = None  # cached copy sits on another device

    if cached is None:
        cached = _to_backend_array(getattr(self, attr_name), xp, w)
        setattr(self, cache_attr, cached)
    return cached
|
|
448
|
+
|
|
449
|
+
def _get_flat_indices(self, xp, w):
    """Return the backend copy of ``_flat_indices`` via ``_get_cached``.

    Returns ``None`` when the attribute is missing or ``None``.
    """
    if getattr(self, '_flat_indices', None) is None:
        return None
    return self._get_cached('_flat_indices', xp, w)
|
|
454
|
+
|
|
455
|
+
def _batched_group_norms_vec(self, coef_feat, xp, w_ref):
    """Return all group norms at once through a zero-padded matrix.

    The features are written into a ``(n_groups, max_group_size)`` zero
    matrix with one fancy-index assignment, after which a single row-wise
    norm yields one value per group.  Non-contiguous layouts are permuted
    with the flat index table first.

    Parameters
    ----------
    coef_feat : array
        Feature coefficients (any intercept already removed).
    xp : module
        Array module of ``coef_feat``.
    w_ref : array
        Reference array used for device placement of the helper arrays.
    """
    widest = int(self._group_sizes.max())
    grid = _backend_zeros((self._n_groups, widest), xp,
                          dtype=coef_feat.dtype, ref_arr=w_ref)
    rows = self._get_cached('_padded_row_idx', xp, w_ref)
    cols = self._get_cached('_padded_col_idx', xp, w_ref)

    if self._is_contiguous:
        ordered = coef_feat
    else:
        ordered = coef_feat[self._get_flat_indices(xp, w_ref)]
    grid[rows, cols] = ordered

    # Padding cells stay zero, so they do not contribute to the row norms.
    return _vector_norm(grid, xp, dim=1)
|
|
474
|
+
|
|
475
|
+
def _proximal_equal(self, w, step, xp, G, gs):
    """Block soft-threshold ``w`` when all ``G`` groups have size ``gs``.

    The first ``G * gs`` entries are viewed as a ``(G, gs)`` matrix (after a
    gather for non-contiguous layouts), each row is scaled by
    ``max(0, 1 - thresh / (norm + 1e-12))`` and the rows are written back
    into a copy of ``w``.  Entries past ``G * gs`` (e.g. an augmented
    intercept) are copied through unchanged.  On torch the compiled kernel
    is used when it is available.
    """
    p_total = G * gs
    feats = w[:p_total]  # drop any augmented intercept
    contiguous = self._is_contiguous

    ordered = feats if contiguous else feats[self._flat_indices]
    blocks = ordered.reshape(G, gs)
    sqrt_pg_arr = self._get_sqrt_pg(xp, w)

    compiled_fn = None
    if xp.__name__ == "torch":
        compiled_fn = _get_group_lasso_torch_compiled_equal()

    if compiled_fn is not None:
        # Torch compiled fast path.
        shrunk = compiled_fn(blocks, sqrt_pg_arr, self.alpha, step)
        result = w.clone()
    else:
        # Generic vectorized path.
        norms = _vector_norm(blocks, xp, dim=1)
        thresh = self.alpha * sqrt_pg_arr * step
        scale = xp.clip(1.0 - thresh / (norms + 1e-12), 0.0, None)
        shrunk = (blocks * scale[:, None]).reshape(-1)
        result = w.copy() if hasattr(w, 'copy') else w.clone()

    if contiguous:
        result[:p_total] = shrunk
    else:
        result[self._flat_indices] = shrunk
    return result
|
|
512
|
+
|
|
513
|
+
def _proximal_padded(self, w, step, xp, G, max_sz):
    """Block soft-threshold ``w`` for groups of unequal size.

    Features are placed into a zero-padded ``(G, max_sz)`` matrix, every row
    is scaled by ``max(0, 1 - thresh / (norm + 1e-12))`` and the real cells
    are read back into a copy of ``w``.  Entries beyond the grouped features
    (e.g. an augmented intercept) are copied through unchanged.
    """
    p_total = int(self._group_sizes.sum())
    feats = w[:p_total]  # drop any augmented intercept
    contiguous = self._is_contiguous

    # Gather into the padded matrix with one fancy-index assignment.
    grid = _backend_zeros((G, max_sz), xp, dtype=w.dtype, ref_arr=w)
    rows = self._get_cached('_padded_row_idx', xp, w)
    cols = self._get_cached('_padded_col_idx', xp, w)
    if contiguous:
        grid[rows, cols] = feats
    else:
        flat_idx_dev = self._get_flat_indices(xp, w)
        grid[rows, cols] = feats[flat_idx_dev]

    # Row norms and the per-group shrink factor.
    norms = _vector_norm(grid, xp, dim=1)
    thresh = self.alpha * self._get_sqrt_pg(xp, w) * step
    scale = xp.clip(1.0 - thresh / (norms + 1e-12), 0.0, None)
    shrunk = (grid * scale[:, None])[rows, cols]

    # Scatter the real cells back into a copy of the input.
    result = w.copy() if hasattr(w, 'copy') else w.clone()
    if contiguous:
        result[:p_total] = shrunk
    else:
        result[flat_idx_dev] = shrunk
    return result
|
|
545
|
+
|
|
546
|
+
# ----------------------------------------------------------------
|
|
547
|
+
|
|
548
|
+
def get_params(self) -> dict:
    """Return the parent's parameters extended with ``alpha`` and ``n_groups``.

    ``n_groups`` is reported as 0 while ``_group_indices`` is empty/falsy.
    """
    params = super().get_params()
    group_count = self._n_groups if self._group_indices else 0
    params.update(alpha=self.alpha, n_groups=group_count)
    return params
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
class AdaptiveGroupLassoPenalty(GroupLassoPenalty):
|
|
558
|
+
"""Group Lasso with per-group weights for LLA linearization of group SCAD/MCP.
|
|
559
|
+
|
|
560
|
+
The penalty is:
|
|
561
|
+
P(w) = alpha * sum_g weights_g * sqrt(p_g) * ||w_g||_2
|
|
562
|
+
|
|
563
|
+
where weights_g are per-group LLA weights.
|
|
564
|
+
"""
|
|
565
|
+
|
|
566
|
+
name = "adaptive_group_lasso"
|
|
567
|
+
|
|
568
|
+
def __init__(self, groups, alpha=1.0, weights=None):
    """Create a group-lasso penalty with optional per-group weights.

    Parameters
    ----------
    groups
        Group specification, forwarded unchanged to ``GroupLassoPenalty``.
    alpha : float, default 1.0
        Overall penalty strength, forwarded to ``GroupLassoPenalty``.
    weights : array-like of shape (n_groups,) or None
        Per-group weights.  ``None`` means uniform weights, i.e. the same
        penalty as the plain group lasso.
    """
    super().__init__(alpha=alpha, groups=groups)
    self._group_weights = weights
    # Backend-array caches filled lazily by ``_get_group_weights`` and reset
    # by ``set_weights``; created here so every instance attribute is
    # defined at construction time.
    self._group_weights_torch = None
    self._group_weights_cupy = None
|
|
573
|
+
|
|
574
|
+
def set_weights(self, weights):
    """Replace the per-group weights (numpy array, shape (n_groups,)).

    Both cached backend copies are reset to ``None`` so the next call to
    ``_get_group_weights`` converts the new values.
    """
    self._group_weights = weights
    # Drop the cached device tensors built from the previous weights.
    for cache_slot in ('_group_weights_torch', '_group_weights_cupy'):
        setattr(self, cache_slot, None)
|
|
580
|
+
|
|
581
|
+
def _get_group_weights(self, xp, w):
    """Return ``_group_weights`` as a backend array, converting lazily.

    Returns ``None`` when no weights are set.  Otherwise the converted array
    is stored in ``_group_weights_torch`` when ``xp`` is torch and in
    ``_group_weights_cupy`` for every other namespace, and that stored
    object is returned on later calls until ``set_weights`` clears it.
    """
    host_weights = self._group_weights
    if host_weights is None:
        return None

    slot = '_group_weights_torch' if xp.__name__ == "torch" else '_group_weights_cupy'
    cached = getattr(self, slot, None)
    if cached is None:
        cached = _to_backend_array(host_weights, xp, w)
        setattr(self, slot, cached)
    return cached
|
|
593
|
+
|
|
594
|
+
def _proximal_loop(self, w, step, xp):
    """Weighted group soft-threshold, handling one group per iteration.

    For group ``g`` the threshold is
    ``alpha * weight_g * _sqrt_pg[g] * step`` (``weight_g`` is 1.0 when no
    weights are set).  A group whose norm exceeds the threshold is scaled by
    ``1 - threshold / norm``; every other group is set to zero.  ``w`` itself
    is left untouched and a modified copy is returned.
    """
    if hasattr(w, 'copy'):
        out = w.copy()
    else:
        out = w.clone()

    for g, members in enumerate(self._group_indices):
        block_vals = w[members]
        block_norm = float(xp.linalg.norm(block_vals))
        if self._group_weights is None:
            weight = 1.0
        else:
            weight = float(self._group_weights[g])
        cutoff = self.alpha * weight * self._sqrt_pg[g] * step
        # Shrink toward zero, or kill the whole group at/below the cutoff.
        out[members] = (
            block_vals * (1.0 - cutoff / block_norm)
            if block_norm > cutoff
            else 0.0
        )
    return out
|
|
607
|
+
|
|
608
|
+
def _proximal_equal(self, w, step, xp, G, gs):
    """Weighted group soft-threshold when all ``G`` groups have size ``gs``.

    The first ``G * gs`` entries of ``w`` (reordered through
    ``_flat_indices`` when the groups are not contiguous) are viewed as a
    ``(G, gs)`` matrix, each row is multiplied by
    ``max(0, 1 - alpha * weight_g * sqrt_pg * step / norm_g)`` and the result
    is written into a copy of ``w``.  Entries past ``G * gs`` (e.g. an
    augmented intercept) are carried over unchanged.

    Parameters
    ----------
    w : array or tensor
        Coefficient vector; it is copied, never modified in place.
    step
        Step size multiplying the per-group threshold.
    xp : module
        Array namespace (``clamp`` is used for torch, ``clip`` otherwise).
    G : int
        Number of groups.
    gs : int
        Common group size.

    Returns
    -------
    array or tensor
        Copy of ``w`` with the grouped entries shrunk.
    """
    p_total = G * gs
    w_feat = w[:p_total]  # handle augmented intercept

    if self._is_contiguous:
        w_mat = w_feat.reshape(G, gs)
    else:
        w_mat = w_feat[self._flat_indices].reshape(G, gs)

    sqrt_pg_arr = self._get_sqrt_pg(xp, w)
    weights_arr = self._get_group_weights(xp, w)

    norms = _vector_norm(w_mat, xp, dim=1)
    if weights_arr is None:
        # Uniform weights: leave the factor out instead of materialising a
        # ones array.  The previous code moved that array with ``.to()``
        # whenever ``w`` had a ``device`` attribute, which raises
        # AttributeError for CuPy arrays and NumPy >= 2.0 arrays (both expose
        # ``.device`` but have no ``.to``).
        thresh = self.alpha * sqrt_pg_arr * step
    else:
        thresh = self.alpha * weights_arr * sqrt_pg_arr * step

    # The 1e-12 keeps the division finite for an all-zero group.
    ratio = 1.0 - thresh / (norms + 1e-12)
    if xp.__name__ == "torch":
        scale = xp.clamp(ratio, min=0.0)
    else:
        scale = xp.clip(ratio, 0.0, None)
    scaled_flat = (w_mat * scale[:, None]).reshape(-1)

    result = w.clone() if hasattr(w, 'clone') else w.copy()
    if self._is_contiguous:
        result[:p_total] = scaled_flat
    else:
        result[self._flat_indices] = scaled_flat
    return result
|
|
636
|
+
|
|
637
|
+
def _proximal_padded(self, w, step, xp, G, max_sz):
    """Weighted group soft-threshold for unequal group sizes (padded path).

    The grouped coefficients are scattered into a ``(G, max_sz)`` buffer of
    zeros with one fancy-indexed assignment, row-wise norms give one norm
    per group, each row is multiplied by
    ``max(0, 1 - alpha * weight_g * sqrt_pg * step / norm_g)`` and the scaled
    values are gathered back into a copy of ``w``.  Entries past the summed
    group sizes (e.g. an augmented intercept) are carried over unchanged.

    Parameters
    ----------
    w : array or tensor
        Coefficient vector; it is copied, never modified in place.
    step
        Step size multiplying the per-group threshold.
    xp : module
        Array namespace (``clamp`` is used for torch, ``clip`` otherwise).
    G : int
        Number of rows (groups) in the padded buffer.
    max_sz : int
        Number of columns (largest group size) in the padded buffer.

    Returns
    -------
    array or tensor
        Copy of ``w`` with the grouped entries shrunk.
    """
    p_total = int(self._group_sizes.sum())
    w_feat = w[:p_total]  # handle augmented intercept

    padded = _backend_zeros((G, max_sz), xp, dtype=w.dtype, ref_arr=w)
    row_idx_dev = self._get_cached('_padded_row_idx', xp, w)
    col_idx_dev = self._get_cached('_padded_col_idx', xp, w)
    if self._is_contiguous:
        padded[row_idx_dev, col_idx_dev] = w_feat
    else:
        flat_idx_dev = self._get_flat_indices(xp, w)
        padded[row_idx_dev, col_idx_dev] = w_feat[flat_idx_dev]

    norms = _vector_norm(padded, xp, dim=1)
    sqrt_pg_arr = self._get_sqrt_pg(xp, w)
    weights_arr = self._get_group_weights(xp, w)
    if weights_arr is None:
        # Uniform weights: leave the factor out instead of materialising a
        # ones array.  The previous code moved that array with ``.to()``
        # whenever ``w`` had a ``device`` attribute, which raises
        # AttributeError for CuPy arrays and NumPy >= 2.0 arrays (both expose
        # ``.device`` but have no ``.to``).
        thresh = self.alpha * sqrt_pg_arr * step
    else:
        thresh = self.alpha * weights_arr * sqrt_pg_arr * step

    # The 1e-12 keeps the division finite for an all-zero group.
    ratio = 1.0 - thresh / (norms + 1e-12)
    if xp.__name__ == "torch":
        scale = xp.clamp(ratio, min=0.0)
    else:
        scale = xp.clip(ratio, 0.0, None)
    padded_scaled = padded * scale[:, None]

    result = w.copy() if hasattr(w, 'copy') else w.clone()
    if self._is_contiguous:
        result[:p_total] = padded_scaled[row_idx_dev, col_idx_dev]
    else:
        result[flat_idx_dev] = padded_scaled[row_idx_dev, col_idx_dev]
    return result
|
|
672
|
+
|
|
673
|
+
def get_params(self) -> dict:
    """Return the parent's parameters plus the current per-group ``weights``."""
    params = super().get_params()
    params["weights"] = self._group_weights
    return params
|