statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,401 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Generalized Additive Model (GAM) with GPU support.
|
|
3
|
+
|
|
4
|
+
Implements GAM using penalized B-splines with automatic smoothing
|
|
5
|
+
parameter selection via Generalized Cross-Validation (GCV).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
# Public API of this module: only the GAM estimator is exported.
__all__ = ["GAM"]
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
from typing import Optional, Union
|
|
14
|
+
|
|
15
|
+
from statgpu._base import BaseEstimator
|
|
16
|
+
from statgpu._config import Device
|
|
17
|
+
from statgpu.backends import _torch_dev, _to_numpy, xp_zeros, xp_ones, xp_asarray, xp_copy
|
|
18
|
+
from statgpu.nonparametric.splines._bspline_basis import bspline_basis
|
|
19
|
+
from statgpu.nonparametric.splines._penalized import (
|
|
20
|
+
difference_penalty,
|
|
21
|
+
penalized_ls,
|
|
22
|
+
generalized_cross_validation,
|
|
23
|
+
select_lambda_gcv,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GAM(BaseEstimator):
    """
    Generalized Additive Model (GAM) using penalized B-splines.

    Fits a smooth function for each feature using B-spline basis with
    a difference penalty for smoothness. Smoothing parameters can be
    specified or automatically selected via GCV.

    The model is: y = alpha + sum_j f_j(x_j) + epsilon

    where each f_j is represented as a penalized B-spline.

    Parameters
    ----------
    n_splines : int, default=20
        Target number of basis functions per feature. It must exceed
        ``degree + 1``. The realized number can be smaller when tied data
        produce duplicate quantile knots.
    degree : int, default=3
        Degree of B-spline basis (3 = cubic).
    lam : float or None, default=None
        Smoothing parameter. If None, automatically selected via GCV.
    penalty_order : int, default=2
        Order of difference penalty (2 = second differences).
    device : str or Device, default='auto'
        Computation device: 'cpu', 'cuda', or 'auto'.
    n_jobs : int or None, default=None
        Number of parallel jobs.

    Attributes
    ----------
    coef_ : array, shape (1 + sum_j n_basis_j,)
        Intercept followed by the spline coefficients of each feature. The
        coefficients refer to the column-centered basis used during ``fit``,
        and the array stays on the computation device.
    intercept_ : float
        Intercept term. Because the spline columns are centered, this is
        the fitted overall level of y.
    edf_ : float
        Total effective degrees of freedom.
    gcv_score_ : float or None
        Minimum of the GCV scores returned by the lambda search, or None
        when ``lam`` was given explicitly.
    lam_ : float
        Smoothing parameter used (after auto-selection if applicable).
    knots_ : list of arrays
        Interior knots for each feature.
    n_features_ : int
        Number of features in training data.

    Examples
    --------
    >>> import numpy as np
    >>> from statgpu.semiparametric import GAM
    >>> X = np.random.randn(100, 3)
    >>> y = np.sin(X[:, 0]) + 0.5 * X[:, 1] ** 2 + np.random.randn(100) * 0.1
    >>> gam = GAM(n_splines=15, lam=1.0)
    >>> gam.fit(X, y)
    >>> y_pred = gam.predict(X)
    """

    def __init__(
        self,
        n_splines: int = 20,
        degree: int = 3,
        lam: Optional[float] = None,
        penalty_order: int = 2,
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
    ):
        super().__init__(device=device, n_jobs=n_jobs)
        self.n_splines = n_splines
        self.degree = degree
        self.lam = lam
        self.penalty_order = penalty_order

        # Fitted attributes (populated by fit()).
        self.coef_ = None
        self.intercept_ = None
        self.edf_ = None
        self.gcv_score_ = None
        self.lam_ = None
        self.knots_ = None
        self.n_features_ = None

    def _get_xp(self):
        """Get the array module for computation.

        Returns ``backend.xp`` (the raw array module), so callers can use
        ``xp.asarray`` and similar functions directly. It delegates to the
        parent's ``_get_backend()``.
        """
        backend = super()._get_backend(backend='auto')
        return backend.xp

    @staticmethod
    def _as_2d(X):
        """Return X as a 2-D (n_samples, n_features) array.

        A 1-D array is treated as a single feature column. Any other rank
        raises ValueError.
        """
        if X.ndim == 1:
            return X.reshape(-1, 1)
        if X.ndim != 2:
            raise ValueError(
                f"X must be 1-D or 2-D, got an array with {X.ndim} dimensions"
            )
        return X

    def _create_knots(self, x_col, n_splines, xp):
        """
        Create interior knots for a feature using quantiles.

        Parameters
        ----------
        x_col : array, shape (n,)
            Feature values.
        n_splines : int
            Number of basis functions.
        xp : module
            Array module.

        Returns
        -------
        knots : array, shape (<= n_splines - degree - 1,)
            Sorted, unique interior knots lying strictly inside
            ``(min(x_col), max(x_col))``.

        Raises
        ------
        ValueError
            If ``n_splines <= degree + 1``.
        """
        # Boundary knots are excluded; bspline_basis adds them itself.
        n_interior = n_splines - self.degree - 1
        if n_interior <= 0:
            raise ValueError(
                f"n_splines ({n_splines}) must be greater than degree ({self.degree}) + 1"
            )

        # Evenly spaced percentiles strictly between 0 and 100.
        percentiles = np.linspace(0, 100, n_interior + 2)[1:-1]

        # Percentiles are computed on the host.
        x_np = _to_numpy(x_col)

        # np.unique removes the duplicate knots that discrete data produce.
        knots = np.unique(np.percentile(x_np, percentiles))

        # With heavily tied data, a quantile can coincide with the minimum or
        # maximum. Keep only knots strictly inside the data range so they
        # never collide with the boundary knots.
        knots = knots[(knots > x_np.min()) & (knots < x_np.max())]

        return xp_asarray(knots, dtype=xp.float64, xp=xp, ref_arr=x_col)

    def _build_basis(self, X, xp):
        """
        Build combined basis matrix for all features.

        Side effects: this method appends each feature's knots to
        ``self.knots_`` and its training min/max to ``self._boundary_lo_``
        and ``self._boundary_hi_``. ``fit`` resets those lists beforehand.

        Parameters
        ----------
        X : array, shape (n, p)
            Input features.
        xp : module
            Array module.

        Returns
        -------
        B : array, shape (n, 1 + sum(n_basis_j))
            Combined basis matrix with intercept column.
        penalty : array, shape (1 + sum(n_basis_j), 1 + sum(n_basis_j))
            Block-diagonal penalty matrix (intercept not penalized).
        """
        n, p = X.shape
        basis_blocks = []
        penalty_blocks = []
        total_basis = 0

        for j in range(p):
            x_col = X[:, j]

            knots_j = self._create_knots(x_col, self.n_splines, xp)
            self.knots_.append(knots_j)

            # Training range, reused by predict() as the spline boundary.
            self._boundary_lo_.append(float(xp.min(x_col)))
            self._boundary_hi_.append(float(xp.max(x_col)))

            B_j = bspline_basis(x_col, knots_j, degree=self.degree, xp=xp)
            n_basis_j = B_j.shape[1]

            S_j = difference_penalty(self.penalty_order, n_basis_j, xp)

            basis_blocks.append(B_j)
            penalty_blocks.append(S_j)
            total_basis += n_basis_j

        # Combine basis matrices: [1, B_1, B_2, ..., B_p]
        intercept_col = xp_ones((n, 1), xp.float64, xp, X)
        B = xp.hstack([intercept_col] + basis_blocks)

        # Block-diagonal penalty sized to match B. Row/col 0 (intercept)
        # is left at zero so the intercept is not penalized.
        full_size = 1 + total_basis
        penalty = xp_zeros((full_size, full_size), xp.float64, xp, X)
        offset = 1
        for S_j in penalty_blocks:
            n_j = S_j.shape[0]
            penalty[offset:offset + n_j, offset:offset + n_j] = S_j
            offset += n_j

        return B, penalty

    def fit(self, X, y=None, **fit_params):
        """
        Fit the GAM model.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features) or (n_samples,)
            Training data. A 1-D array is treated as a single feature.
        y : array-like, shape (n_samples,)
            Target values (required; the default of None exists only to
            keep the BaseEstimator signature).

        Returns
        -------
        self : GAM
            Fitted model.

        Raises
        ------
        ValueError
            If y is missing, X has an unsupported rank, or X and y have
            different numbers of samples.
        """
        if y is None:
            raise ValueError("GAM.fit requires target values y; got None.")

        xp = self._get_xp()

        # On the torch backend, choose the device explicitly. An explicit
        # CUDA request always uses CUDA. Otherwise CUDA is used only when it
        # is available and the user did not explicitly ask for the CPU.
        _ref = None
        if xp.__name__ == "torch":
            import torch

            _dev = getattr(self, 'device', None)
            _dev_value = getattr(_dev, 'value', _dev)
            if _dev is not None and hasattr(_dev, 'value') and _dev.value in ('cuda', 'torch'):
                _ref = torch.empty(0, device="cuda")
            elif _dev_value != 'cpu' and torch.cuda.is_available():
                _ref = torch.empty(0, device="cuda")
        X = self._as_2d(xp_asarray(X, dtype=xp.float64, xp=xp, ref_arr=_ref))
        y = xp_asarray(y, dtype=xp.float64, xp=xp, ref_arr=X).ravel()

        n, p = X.shape
        if y.shape[0] != n:
            raise ValueError(
                f"X has {n} samples but y has {y.shape[0]}"
            )

        self.n_features_ = p
        self.knots_ = []
        self._boundary_lo_ = []
        self._boundary_hi_ = []

        B, penalty = self._build_basis(X, xp)

        # Center the spline columns (not the intercept) so the intercept
        # captures the overall mean of y. This keeps the intercept
        # identifiable, because the spline basis can also represent constants.
        self._basis_mean_ = xp.mean(B[:, 1:], axis=0)
        B_centered = xp_copy(B)
        B_centered[:, 1:] = B[:, 1:] - self._basis_mean_

        if self.lam is None:
            best_lam, gcv_scores = select_lambda_gcv(
                B_centered, y, penalty, xp=xp
            )
            self.lam_ = best_lam
            self.gcv_score_ = float(xp.min(gcv_scores))
        else:
            self.lam_ = self.lam
            self.gcv_score_ = None

        beta, edf = penalized_ls(B_centered, y, penalty, self.lam_, xp)

        self.coef_ = beta
        self.intercept_ = float(beta[0])
        self.edf_ = float(edf) if not isinstance(edf, float) else edf
        self._fitted = True

        # Keep only a tiny array on the training device as the placement
        # reference for predict(), rather than the full training matrix
        # (which would otherwise stay alive on the GPU).
        self._xp = xp
        self._xp_asarray_ref_ = xp_zeros((1,), xp.float64, xp, X)

        return self

    def predict(self, X):
        """
        Predict using the fitted GAM model.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features) or (n_samples,)
            Input features. A 1-D array is treated as a single feature.

        Returns
        -------
        y_pred : numpy.ndarray, shape (n_samples,)
            Predicted values (always returned as a host NumPy array).

        Raises
        ------
        ValueError
            If the number of features differs from the one seen in fit.
        """
        self._check_is_fitted()

        # Re-resolve the backend in case the device changed since fit().
        xp = self._get_xp()
        X = self._as_2d(
            xp_asarray(X, dtype=xp.float64, xp=xp, ref_arr=self._xp_asarray_ref_)
        )

        n, p = X.shape
        if p != self.n_features_:
            raise ValueError(
                f"X has {p} features, but model was fitted with {self.n_features_}"
            )

        # Use the training boundaries, so small or narrow prediction batches
        # do not trigger "knots must be strictly within boundary" errors.
        basis_blocks = []
        for j in range(p):
            B_j = bspline_basis(
                X[:, j], self.knots_[j], degree=self.degree, xp=xp,
                boundary_lo=self._boundary_lo_[j],
                boundary_hi=self._boundary_hi_[j],
            )
            basis_blocks.append(B_j)

        intercept_col = xp_ones((n, 1), xp.float64, xp, X)
        B = xp.hstack([intercept_col] + basis_blocks)

        # Apply the centering learned during fit.
        B[:, 1:] = B[:, 1:] - self._basis_mean_

        y_pred = B @ self.coef_

        return _to_numpy(y_pred)

    def summary(self):
        """
        Print a summary of the fitted GAM model.

        Returns
        -------
        summary_dict : dict
            Dictionary containing model summary information.
        """
        self._check_is_fitted()

        summary_dict = {
            'n_features': self.n_features_,
            'n_splines_per_feature': self.n_splines,
            'spline_degree': self.degree,
            'penalty_order': self.penalty_order,
            'smoothing_parameter': self.lam_,
            'effective_df': self.edf_,
            'intercept': self.intercept_,
        }

        if self.gcv_score_ is not None:
            summary_dict['gcv_score'] = self.gcv_score_

        print("=" * 50)
        print("GAM Model Summary")
        print("=" * 50)
        print(f"Number of features: {self.n_features_}")
        print(f"B-splines per feature: {self.n_splines}")
        print(f"Spline degree: {self.degree}")
        print(f"Penalty order: {self.penalty_order}")
        print(f"Smoothing parameter (lambda): {self.lam_:.6g}")
        print(f"Effective degrees of freedom: {self.edf_:.2f}")
        print(f"Intercept: {self.intercept_:.6f}")
        if self.gcv_score_ is not None:
            print(f"GCV score: {self.gcv_score_:.6f}")
        print("=" * 50)

        return summary_dict

    def get_params(self, deep=True):
        """Get parameters for this estimator."""
        params = super().get_params(deep)
        params.update({
            'n_splines': self.n_splines,
            'degree': self.degree,
            'lam': self.lam,
            'penalty_order': self.penalty_order,
        })
        return params

    def set_params(self, **params):
        """Set parameters for this estimator.

        GAM-specific hyper-parameters are set directly. Any other key is
        forwarded to ``BaseEstimator.set_params``.
        """
        for key, value in params.items():
            if key in ('n_splines', 'degree', 'lam', 'penalty_order'):
                setattr(self, key, value)
            else:
                super().set_params(**{key: value})
        return self
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Generic optimization solvers for penalized loss functions.
|
|
2
|
+
|
|
3
|
+
These solvers work with any loss that implements the GLMLoss interface
|
|
4
|
+
(value, gradient, fused_value_and_gradient, lipschitz, hessian, preprocess)
|
|
5
|
+
and any penalty with a proximal operator.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# Public API of the solvers package. Each name is re-exported from the
# private submodule imported below (_fista, _fista_bb, _fista_lla, _newton,
# _lbfgs, _admm, _convergence).
__all__ = [
    "fista_solver",
    "fista_bb_solver",
    "fista_lla_path",
    "newton_solver",
    "lbfgs_solver",
    "admm_solver",
    "ConvergenceWarning",
]
|
|
17
|
+
|
|
18
|
+
from ._convergence import ConvergenceWarning
|
|
19
|
+
from ._fista import fista_solver
|
|
20
|
+
from ._fista_bb import fista_bb_solver
|
|
21
|
+
from ._fista_lla import fista_lla_path
|
|
22
|
+
from ._newton import newton_solver
|
|
23
|
+
from ._lbfgs import lbfgs_solver
|
|
24
|
+
from ._admm import admm_solver
|
statgpu/solvers/_admm.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
"""ADMM solver for penalized GLM optimization.
|
|
2
|
+
|
|
3
|
+
Reformulates min_w f(Xw; y) + p(w) as a consensus ADMM problem and solves
|
|
4
|
+
via alternating direction method of multipliers. The w-update (smooth
|
|
5
|
+
subproblem) uses either a direct Cholesky solve (for squared-error loss with
|
|
6
|
+
moderate dimensionality) or Nesterov-accelerated gradient descent. The z-update
|
|
7
|
+
reuses the penalty proximal operator and is element-wise / GPU-friendly.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import warnings
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
from statgpu.backends import _resolve_backend
|
|
17
|
+
from statgpu.backends._array_ops import (
|
|
18
|
+
_abs_sum_dev,
|
|
19
|
+
_copy_arr,
|
|
20
|
+
_device_leq,
|
|
21
|
+
_norm2_dev,
|
|
22
|
+
_sync_scalars,
|
|
23
|
+
_zeros,
|
|
24
|
+
_zeros_like,
|
|
25
|
+
)
|
|
26
|
+
from ._convergence import ConvergenceWarning
|
|
27
|
+
from ._utils import (
|
|
28
|
+
_nesterov_momentum,
|
|
29
|
+
_validate_uniform_sample_weight,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
# Only the solver entry point is public; everything else in this module is private.
__all__ = ["admm_solver"]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def admm_solver(
    loss: "GLMLoss",
    penalty: "Penalty | None",
    X,
    y,
    max_iter: int = 200,
    tol: float = 1e-4,
    rho: float = 1.0,
    adaptive_rho: bool = True,
    cg_max_iter: int = 30,
    cg_tol: float = 1e-6,
    init_coef=None,
    sample_weight=None,
) -> tuple:
    """ADMM solver for penalized GLM optimization.

    Reformulates min_w f(Xw; y) + p(w) as:
        min_{w,z} f(Xw; y) + p(z)  s.t.  w = z

    and runs scaled-form ADMM:
        w^{k+1} = argmin_w f(Xw; y) + (rho/2)||w - z^k + u^k||^2
        z^{k+1} = prox_{p/rho}(w^{k+1} + u^k)
        u^{k+1} = u^k + w^{k+1} - z^{k+1}

    The w-update is solved in closed form when the loss sets
    ``_supports_cholesky`` and ``n_features <= 2000``. In that case
    ``loss.hessian + rho*I`` is factored once and reused. Otherwise, or if
    the factorization fails, the w-update uses up to ``cg_max_iter`` steps
    of Nesterov-accelerated gradient descent. The ``cg_`` prefix on that
    parameter is historical. The z-update reuses ``penalty.proximal()``.

    Supports numpy / cupy / torch backends via auto-detection of X.

    Parameters
    ----------
    loss : GLMLoss
        Provides ``preprocess``, ``gradient``, ``lipschitz`` and (for the
        Cholesky path) ``hessian``.
    penalty : Penalty or None
        Provides ``proximal(v, step, backend=...)``. ``None`` means no
        penalty, in which case the z-update is the identity.
    X, y : arrays
    max_iter : int
        Maximum ADMM outer iterations.
    tol : float
        Convergence tolerance for both the primal residual ||w - z|| and the
        dual residual rho * ||z - z_old||.
    rho : float
        Initial augmented Lagrangian penalty parameter.
    adaptive_rho : bool
        Double rho (capped at 1e4) when the primal residual exceeds 10x the
        dual residual, or halve it (floored at 1e-4) in the opposite case.
        The scaled dual variable is rescaled whenever rho changes. This is
        forced off on the Cholesky path, because the cached factor depends
        on rho.
    cg_max_iter : int
        Maximum inner gradient iterations for the w-update (gradient path).
    cg_tol : float
        Inner stopping threshold: iteration stops once sum|delta w| falls
        to ``cg_tol * n_features`` or below.
    init_coef : array, optional
        Initial coefficients.
    sample_weight : array, optional
        Checked by ``_validate_uniform_sample_weight`` and passed to
        ``loss.gradient``.

    Returns
    -------
    coef : array
        The z iterate, i.e. the output of the proximal step.
    n_iter : int
        Number of outer iterations performed.

    Warns
    -----
    ConvergenceWarning
        If the residual criterion is not met within ``max_iter`` iterations.
    """
    backend = _resolve_backend("auto", X)
    X_proc, y_proc = loss.preprocess(X, y)
    n_features = X_proc.shape[1]

    # --- Initialisation ---
    if init_coef is not None:
        w = (
            _copy_arr(init_coef)
            if hasattr(init_coef, "copy") or hasattr(init_coef, "clone")
            else np.array(init_coef).copy()
        )
    else:
        w = _zeros(n_features, backend, ref_tensor=X)

    z = _copy_arr(w)
    u = _zeros_like(w)

    if sample_weight is not None:
        _validate_uniform_sample_weight(sample_weight, X_proc.shape[0], "admm_solver")

    def _grad_w(w_vec, z_cur, u_cur):
        """Gradient of f(w) + (rho/2)||w - z_cur + u_cur||^2 w.r.t. w.

        ``rho`` is read from the enclosing scope at call time, so adaptive
        rho updates are picked up automatically.
        """
        g = loss.gradient(X_proc, y_proc, w_vec, sample_weight=sample_weight)
        return g + rho * (w_vec - z_cur + u_cur)

    def _build_cholesky_solve():
        """Factor ``hessian + rho*I`` once.

        Returns a ``rhs -> w`` solver, or None if the factorization is
        unavailable or fails (for example, a non-positive-definite matrix
        caused by collinear features).
        """
        hess = loss.hessian(X_proc, y_proc, w)
        if not hasattr(hess, "shape"):
            return None
        try:
            if backend == "numpy":
                from scipy.linalg import solve_triangular

                L = np.linalg.cholesky(hess + rho * np.eye(n_features, dtype=hess.dtype))

                def solve(rhs):
                    tmp = solve_triangular(L, rhs, lower=True)
                    return solve_triangular(L.T, tmp, lower=False)

                return solve

            if backend == "cupy":
                import cupy as cp

                L = cp.linalg.cholesky(hess + rho * cp.eye(n_features, dtype=hess.dtype))
                # Depending on cupyx.errstate, CuPy may return NaNs instead
                # of raising on a non-positive-definite input, so check
                # explicitly (a single one-time sync).
                if not bool(cp.isfinite(L).all()):
                    return None
                try:
                    from cupyx.scipy.linalg import solve_triangular as cp_solve_tri
                except ImportError:
                    def solve(rhs):
                        return cp.linalg.solve(L.T, cp.linalg.solve(L, rhs))
                else:
                    def solve(rhs):
                        tmp = cp_solve_tri(L, rhs, lower=True)
                        return cp_solve_tri(L.T, tmp, lower=False)
                return solve

            import torch

            eye = torch.eye(n_features, dtype=hess.dtype, device=hess.device)
            L = torch.linalg.cholesky(hess + rho * eye)

            def solve(rhs):
                tmp = torch.linalg.solve_triangular(L, rhs.unsqueeze(1), upper=False)
                return torch.linalg.solve_triangular(L.T, tmp, upper=True).squeeze(1)

            return solve
        except (np.linalg.LinAlgError, ValueError, RuntimeError):
            return None

    chol_solve = None
    if getattr(loss, "_supports_cholesky", False) and n_features <= 2000:
        chol_solve = _build_cholesky_solve()
    use_cholesky = chol_solve is not None

    # Exactly one w-update strategy is fully set up here. Previously, a
    # failed factorization left lr_sub undefined and crashed the loop.
    L_f = None
    lr_sub = None
    neg_grad_zero = None
    if use_cholesky:
        # The cached factor embeds rho, so rho must stay fixed.
        adaptive_rho = False
        # Constant part of the closed-form rhs: -grad_f(0) (= Xty/n for
        # squared error).
        neg_grad_zero = -loss.gradient(
            X_proc, y_proc, _zeros_like(w), sample_weight=sample_weight
        )
    else:
        # Gradient step size 1/(L_f + rho) for the rho-strongly-convex subproblem.
        L_f = loss.lipschitz(X_proc, w, y=y_proc)
        if L_f <= 0:
            L_f = 1.0
        lr_sub = 1.0 / (L_f + rho + 1e-8)

    inner_threshold = cg_tol * n_features
    converged = False
    n_iter = 0

    for n_iter in range(1, max_iter + 1):
        z_old = _copy_arr(z)

        # --- w-update ---
        if use_cholesky:
            # Quadratic f: (H + rho*I) w = -grad_f(0) + rho*(z - u)
            w = chol_solve(neg_grad_zero + rho * (z - u))
        else:
            w_new = _copy_arr(w)
            w_mom = _copy_arr(w)
            t_mom = 1.0
            for _ in range(cg_max_iter):
                w_prev = _copy_arr(w_new)
                g_sub = _grad_w(w_mom, z, u)
                w_next = w_mom - lr_sub * g_sub
                beta_mom, t_mom = _nesterov_momentum(t_mom)
                w_mom = w_next + beta_mom * (w_next - w_new)
                w_new = w_next
                diff_dev = _abs_sum_dev(w_next - w_prev)
                # Non-numpy backends compare on device to avoid an extra sync.
                if backend != "numpy":
                    if _device_leq(diff_dev, inner_threshold):
                        break
                elif diff_dev < inner_threshold:
                    break
            w = w_new

        # --- z-update: proximal operator ---
        # Contract: proximal(v, step) = argmin_x step*P(x) + (1/2)||x - v||^2,
        # so the ADMM z-update is proximal(w + u, 1/rho).
        v = w + u
        if penalty is None:
            z = v
        else:
            z = penalty.proximal(v, 1.0 / rho, backend=backend)

        # --- u-update: scaled dual ascent ---
        u = u + w - z

        # --- Residuals (one batched host sync) ---
        rp_dev = _norm2_dev(w - z)
        rd_dev = _norm2_dev(z - z_old)
        rp, rd_raw = _sync_scalars(rp_dev, rd_dev, backend=backend)
        r_dual = rho * rd_raw

        if adaptive_rho:
            rho_old = rho
            if rp > 10.0 * r_dual:
                rho = min(rho * 2.0, 1e4)
            elif r_dual > 10.0 * rp:
                rho = max(rho * 0.5, 1e-4)
            if rho != rho_old:
                # u is the *scaled* dual (y / rho). Keep the unscaled dual
                # y unchanged when rho changes (Boyd et al., section 3.4.1).
                u = u * (rho_old / rho)
                lr_sub = 1.0 / (L_f + rho + 1e-8)

        if rp < tol and r_dual < tol:
            converged = True
            break

    if not converged:
        warnings.warn(
            f"admm_solver did not converge within {max_iter} iterations "
            f"(loss={getattr(loss, 'name', '?')}, penalty={getattr(penalty, 'name', '?')}).",
            ConvergenceWarning,
            stacklevel=2,
        )
    # Return z rather than w: at convergence w is approximately z, and z is
    # the proximal (penalty-structured) iterate.
    return z, n_iter
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Solver convergence constants and thresholds."""
|
|
2
|
+
|
|
3
|
+
# Numerical thresholds shared by the solvers in this package.
# NOTE(review): the call sites are not visible in this module. The meanings of
# the constants below without comments are inferred from their names only.
# Confirm them against fista / fista_bb / fista_lla before relying on them.
_SLACK_TOLERANCE = 1e-14
# presumably: divergence guards (coef norm cap, objective ratio / absolute jump)
_DIVERGE_COEF_NORM_CAP = 100.0
_DIVERGE_OBJ_RATIO = 100.0
_DIVERGE_OBJ_ABS = 10.0
# presumably: Barzilai-Borwein restart threshold on the step dot product
_BB_RESTART_DOT_TOL = 1e-14
# presumably: lower bound on Lipschitz estimates, to avoid division by zero
_LIPSCHITZ_FLOOR = 1e-30
_LIPSCHITZ_SAFETY_LOGISTIC_CV = 2.0

# Gradient clipping thresholds (used by fista, fista_bb, fista_lla, _array_ops)
# gmax = max(coef_norm * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR, _GRAD_CLIP_MAX)
# NOTE(review): with max() as written, _GRAD_CLIP_MAX acts as a lower bound on
# gmax, which contradicts its "absolute maximum" description, and it makes
# _GRAD_CLIP_ABS_FLOOR (1e3 < 1e4) irrelevant. The intended formula is
# probably min(...). Verify against the clipping code.
_GRAD_CLIP_COEF_FACTOR = 10.0  # scales with coefficient magnitude
_GRAD_CLIP_ABS_FLOOR = 1e3     # minimum gradient cap (prevents zero-cap at coef=0)
_GRAD_CLIP_MAX = 1e4           # absolute maximum gradient cap
|