statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,401 @@
1
+ """
2
+ Generalized Additive Model (GAM) with GPU support.
3
+
4
+ Implements GAM using penalized B-splines with automatic smoothing
5
+ parameter selection via Generalized Cross-Validation (GCV).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ __all__ = ["GAM"]
11
+
12
+ import numpy as np
13
+ from typing import Optional, Union
14
+
15
+ from statgpu._base import BaseEstimator
16
+ from statgpu._config import Device
17
+ from statgpu.backends import _torch_dev, _to_numpy, xp_zeros, xp_ones, xp_asarray, xp_copy
18
+ from statgpu.nonparametric.splines._bspline_basis import bspline_basis
19
+ from statgpu.nonparametric.splines._penalized import (
20
+ difference_penalty,
21
+ penalized_ls,
22
+ generalized_cross_validation,
23
+ select_lambda_gcv,
24
+ )
25
+
26
+
27
class GAM(BaseEstimator):
    """
    Generalized Additive Model (GAM) using penalized B-splines.

    Fits a smooth function for each feature using a B-spline basis with
    a difference penalty for smoothness. The smoothing parameter can be
    specified or automatically selected via GCV.

    The model is: y = alpha + sum_j f_j(x_j) + epsilon

    where each f_j is represented as a penalized B-spline.

    Parameters
    ----------
    n_splines : int, default=20
        Number of basis functions per feature (before penalty reduction).
        Must exceed ``degree + 1`` so that at least one interior knot exists.
    degree : int, default=3
        Degree of B-spline basis (3 = cubic).
    lam : float or None, default=None
        Smoothing parameter. If None, automatically selected via GCV.
    penalty_order : int, default=2
        Order of difference penalty (2 = second differences).
    device : str or Device, default='auto'
        Computation device: 'cpu', 'cuda', or 'auto'.
    n_jobs : int or None, default=None
        Number of parallel jobs.

    Attributes
    ----------
    coef_ : array, shape (1 + total number of basis columns,)
        Fitted coefficients (intercept first, then the spline coefficients
        of each feature in column order).
    intercept_ : float
        Intercept term (``coef_[0]``).
    edf_ : float
        Total effective degrees of freedom.
    gcv_score_ : float or None
        GCV score if ``lam`` was auto-selected, otherwise None.
    lam_ : float
        Smoothing parameter used (after auto-selection if applicable).
    knots_ : list of arrays
        Interior knots for each feature.
    n_features_ : int
        Number of features in training data.

    Examples
    --------
    >>> import numpy as np
    >>> from statgpu.semiparametric import GAM
    >>> X = np.random.randn(100, 3)
    >>> y = np.sin(X[:, 0]) + 0.5 * X[:, 1] ** 2 + np.random.randn(100) * 0.1
    >>> gam = GAM(n_splines=15, lam=1.0)
    >>> gam.fit(X, y)
    >>> y_pred = gam.predict(X)
    """

    def __init__(
        self,
        n_splines: int = 20,
        degree: int = 3,
        lam: Optional[float] = None,
        penalty_order: int = 2,
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
    ):
        super().__init__(device=device, n_jobs=n_jobs)
        self.n_splines = n_splines
        self.degree = degree
        self.lam = lam
        self.penalty_order = penalty_order

        # Fitted attributes (populated by fit)
        self.coef_ = None
        self.intercept_ = None
        self.edf_ = None
        self.gcv_score_ = None
        self.lam_ = None
        self.knots_ = None
        self.n_features_ = None

    def _get_xp(self):
        """Get the array module for computation.

        Returns ``backend.xp`` (the raw array module) so callers can use
        ``xp.asarray`` etc. directly. Delegates to the parent's
        ``_get_backend()`` which returns a ``BackendBase`` with correct
        device/dtype handling.
        """
        backend = super()._get_backend(backend='auto')
        return backend.xp

    def _create_knots(self, x_col, n_splines, xp):
        """
        Create interior knots for a feature using quantiles.

        Parameters
        ----------
        x_col : array, shape (n,)
            Feature values.
        n_splines : int
            Number of basis functions.
        xp : module
            Array module.

        Returns
        -------
        knots : array, shape (<= n_splines - degree - 1,)
            Sorted interior knots. Fewer than ``n_splines - degree - 1``
            knots are returned when several quantiles coincide (e.g. for
            discrete data), because duplicates are dropped.

        Raises
        ------
        ValueError
            If ``n_splines <= degree + 1`` (no interior knot possible).
        """
        # Boundary knots are excluded here (they'll be added by bspline_basis)
        n_interior = n_splines - self.degree - 1

        if n_interior <= 0:
            raise ValueError(
                f"n_splines ({n_splines}) must be greater than degree ({self.degree}) + 1"
            )

        # Evenly spaced percentiles in (0, 100), dropping the two endpoints
        percentiles = np.linspace(0, 100, n_interior + 2)[1:-1]

        # Percentiles are computed on the host with numpy
        x_np = _to_numpy(x_col)
        knots = np.percentile(x_np, percentiles)

        # Remove duplicate knots (can happen with discrete data)
        knots = np.unique(knots)

        return xp_asarray(knots, dtype=xp.float64, xp=xp, ref_arr=x_col)

    def _build_basis(self, X, xp):
        """
        Build combined basis matrix for all features.

        Appends to ``self.knots_``, ``self._boundary_lo_`` and
        ``self._boundary_hi_`` as a side effect, so those lists must have
        been initialised (``fit`` does this) before calling.

        Parameters
        ----------
        X : array, shape (n, p)
            Input features.
        xp : module
            Array module.

        Returns
        -------
        B : array, shape (n, 1 + sum(n_basis_j))
            Combined basis matrix with intercept column.
        penalty : array, shape (1 + sum(n_basis_j), 1 + sum(n_basis_j))
            Block-diagonal penalty matrix (intercept not penalized).
        """
        n, p = X.shape
        basis_blocks = []
        penalty_blocks = []
        total_basis = 0

        for j in range(p):
            x_col = X[:, j]

            # Create knots for this feature
            knots_j = self._create_knots(x_col, self.n_splines, xp)
            self.knots_.append(knots_j)

            # Store training boundary for prediction
            self._boundary_lo_.append(float(xp.min(x_col)))
            self._boundary_hi_.append(float(xp.max(x_col)))

            # Build B-spline basis
            B_j = bspline_basis(x_col, knots_j, degree=self.degree, xp=xp)
            n_basis_j = B_j.shape[1]

            # Build penalty matrix
            S_j = difference_penalty(self.penalty_order, n_basis_j, xp)

            basis_blocks.append(B_j)
            penalty_blocks.append(S_j)
            total_basis += n_basis_j

        # Combine basis matrices: [1, B_1, B_2, ..., B_p]
        intercept_col = xp_ones((n, 1), xp.float64, xp, X)
        B = xp.hstack([intercept_col] + basis_blocks)

        # Block-diagonal penalty with intercept dimension (not penalized)
        # Size: (1 + total_basis, 1 + total_basis) to match B
        full_size = 1 + total_basis
        penalty = xp_zeros((full_size, full_size), xp.float64, xp, X)
        offset = 1  # Skip intercept (row/col 0)
        for S_j in penalty_blocks:
            n_j = S_j.shape[0]
            penalty[offset:offset + n_j, offset:offset + n_j] = S_j
            offset += n_j

        return B, penalty

    def fit(self, X, y=None, **fit_params):
        """
        Fit the GAM model.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training data.
        y : array-like, shape (n_samples,)
            Target values. Required; the ``None`` default exists only for
            signature compatibility.

        Returns
        -------
        self : GAM
            Fitted model.

        Raises
        ------
        ValueError
            If ``y`` is None, ``X`` is not 2-D, or ``X`` and ``y`` have
            different numbers of samples.
        """
        if y is None:
            raise ValueError("GAM.fit requires target values y, but y is None")

        xp = self._get_xp()

        # Convert to arrays on the correct device
        # For torch backend, ensure arrays land on CUDA (not CPU)
        _ref = None
        if xp.__name__ == "torch":
            import torch
            _dev = getattr(self, 'device', None)
            if _dev is not None and hasattr(_dev, 'value') and _dev.value in ('cuda', 'torch'):
                _ref = torch.empty(0, device="cuda")
            elif torch.cuda.is_available():
                _ref = torch.empty(0, device="cuda")
        X = xp_asarray(X, dtype=xp.float64, xp=xp, ref_arr=_ref)
        if X.ndim != 2:
            raise ValueError(
                f"X must be 2-dimensional (n_samples, n_features), got {X.ndim} dimension(s)"
            )
        y = xp_asarray(y, dtype=xp.float64, xp=xp, ref_arr=X).ravel()

        n, p = X.shape
        if y.shape[0] != n:
            raise ValueError(
                f"X and y have inconsistent numbers of samples: {n} != {y.shape[0]}"
            )

        self.n_features_ = p
        # Reset per-feature state; _build_basis appends to these lists
        self.knots_ = []
        self._boundary_lo_ = []
        self._boundary_hi_ = []

        # Build basis matrix and penalty
        B, penalty = self._build_basis(X, xp)

        # Center spline basis columns (not intercept) so the intercept
        # captures the overall mean of y. This makes the intercept
        # identifiable even though spline basis can represent constants.
        self._basis_mean_ = xp.mean(B[:, 1:], axis=0)
        B_centered = xp_copy(B)
        B_centered[:, 1:] = B[:, 1:] - self._basis_mean_

        # Select smoothing parameter
        if self.lam is None:
            # Auto-select via GCV
            best_lam, gcv_scores = select_lambda_gcv(
                B_centered, y, penalty, xp=xp
            )
            # Store a plain Python float so lam_ matches its documented type
            # even if the selector hands back a 0-d array / device scalar.
            self.lam_ = float(best_lam)
            self.gcv_score_ = float(xp.min(gcv_scores))
        else:
            self.lam_ = self.lam
            self.gcv_score_ = None

        # Fit the model with centered basis
        beta, edf = penalized_ls(B_centered, y, penalty, self.lam_, xp)

        # Store results
        self.coef_ = beta
        self.intercept_ = float(beta[0])
        self.edf_ = float(edf)
        self._fitted = True

        # Store training data info for prediction
        self._xp = xp
        self._xp_asarray_ref_ = X  # device reference for xp_asarray

        return self

    def predict(self, X):
        """
        Predict using the fitted GAM model.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Input features.

        Returns
        -------
        y_pred : array, shape (n_samples,)
            Predicted values, converted via ``_to_numpy``.

        Raises
        ------
        ValueError
            If ``X`` is not 2-D or its number of features differs from the
            training data.
        """
        self._check_is_fitted()

        # Re-resolve backend to handle device changes since fit()
        xp = self._get_xp()
        X = xp_asarray(X, dtype=xp.float64, xp=xp, ref_arr=self._xp_asarray_ref_)
        if X.ndim != 2:
            raise ValueError(
                f"X must be 2-dimensional (n_samples, n_features), got {X.ndim} dimension(s)"
            )

        n, p = X.shape
        if p != self.n_features_:
            raise ValueError(
                f"X has {p} features, but model was fitted with {self.n_features_}"
            )

        # Build basis for prediction (use training boundaries to avoid
        # "knots must be strictly within boundary" errors on small batches)
        basis_blocks = []
        for j in range(p):
            B_j = bspline_basis(
                X[:, j], self.knots_[j], degree=self.degree, xp=xp,
                boundary_lo=self._boundary_lo_[j],
                boundary_hi=self._boundary_hi_[j],
            )
            basis_blocks.append(B_j)

        # Combine: [1, B_1, B_2, ..., B_p]
        intercept_col = xp_ones((n, 1), xp.float64, xp, X)
        B = xp.hstack([intercept_col] + basis_blocks)

        # Apply same centering as in fit
        B[:, 1:] = B[:, 1:] - self._basis_mean_

        # Predict
        y_pred = B @ self.coef_

        return _to_numpy(y_pred)

    def summary(self):
        """
        Print a summary of the fitted GAM model.

        Returns
        -------
        summary_dict : dict
            Dictionary containing model summary information. The
            ``'gcv_score'`` key is present only when the smoothing
            parameter was selected by GCV.
        """
        self._check_is_fitted()

        summary_dict = {
            'n_features': self.n_features_,
            'n_splines_per_feature': self.n_splines,
            'spline_degree': self.degree,
            'penalty_order': self.penalty_order,
            'smoothing_parameter': self.lam_,
            'effective_df': self.edf_,
            'intercept': self.intercept_,
        }

        if self.gcv_score_ is not None:
            summary_dict['gcv_score'] = self.gcv_score_

        print("=" * 50)
        print("GAM Model Summary")
        print("=" * 50)
        print(f"Number of features: {self.n_features_}")
        print(f"B-splines per feature: {self.n_splines}")
        print(f"Spline degree: {self.degree}")
        print(f"Penalty order: {self.penalty_order}")
        print(f"Smoothing parameter (lambda): {self.lam_:.6g}")
        print(f"Effective degrees of freedom: {self.edf_:.2f}")
        print(f"Intercept: {self.intercept_:.6f}")
        if self.gcv_score_ is not None:
            print(f"GCV score: {self.gcv_score_:.6f}")
        print("=" * 50)

        return summary_dict

    def get_params(self, deep=True):
        """Get parameters for this estimator (base parameters plus GAM-specific ones)."""
        params = super().get_params(deep)
        params.update({
            'n_splines': self.n_splines,
            'degree': self.degree,
            'lam': self.lam,
            'penalty_order': self.penalty_order,
        })
        return params

    def set_params(self, **params):
        """Set parameters for this estimator.

        GAM-specific keys are set directly on the instance; every other key
        is forwarded to the parent ``set_params``.
        """
        for key, value in params.items():
            if key in ('n_splines', 'degree', 'lam', 'penalty_order'):
                setattr(self, key, value)
            else:
                super().set_params(**{key: value})
        return self
@@ -0,0 +1,24 @@
1
+ """Generic optimization solvers for penalized loss functions.
2
+
3
+ These solvers work with any loss that implements the GLMLoss interface
4
+ (value, gradient, fused_value_and_gradient, lipschitz, hessian, preprocess)
5
+ and any penalty with a proximal operator.
6
+ """
7
+
8
+ __all__ = [
9
+ "fista_solver",
10
+ "fista_bb_solver",
11
+ "fista_lla_path",
12
+ "newton_solver",
13
+ "lbfgs_solver",
14
+ "admm_solver",
15
+ "ConvergenceWarning",
16
+ ]
17
+
18
+ from ._convergence import ConvergenceWarning
19
+ from ._fista import fista_solver
20
+ from ._fista_bb import fista_bb_solver
21
+ from ._fista_lla import fista_lla_path
22
+ from ._newton import newton_solver
23
+ from ._lbfgs import lbfgs_solver
24
+ from ._admm import admm_solver
@@ -0,0 +1,241 @@
1
+ """ADMM solver for penalized GLM optimization.
2
+
3
+ Reformulates min_w f(Xw; y) + p(w) as a consensus ADMM problem and solves
4
+ via alternating direction method of multipliers. The w-update (smooth
5
+ subproblem) uses either a direct Cholesky solve (for squared-error loss with
6
+ moderate dimensionality) or Nesterov-accelerated gradient descent. The z-update
7
+ reuses the penalty proximal operator and is element-wise / GPU-friendly.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import warnings
13
+
14
+ import numpy as np
15
+
16
+ from statgpu.backends import _resolve_backend
17
+ from statgpu.backends._array_ops import (
18
+ _abs_sum_dev,
19
+ _copy_arr,
20
+ _device_leq,
21
+ _norm2_dev,
22
+ _sync_scalars,
23
+ _zeros,
24
+ _zeros_like,
25
+ )
26
+ from ._convergence import ConvergenceWarning
27
+ from ._utils import (
28
+ _nesterov_momentum,
29
+ _validate_uniform_sample_weight,
30
+ )
31
+
32
+ __all__ = ["admm_solver"]
33
+
34
+
35
def admm_solver(
    loss: "GLMLoss",
    penalty: "Penalty | None",
    X,
    y,
    max_iter: int = 200,
    tol: float = 1e-4,
    rho: float = 1.0,
    adaptive_rho: bool = True,
    cg_max_iter: int = 30,
    cg_tol: float = 1e-6,
    init_coef=None,
    sample_weight=None,
) -> tuple:
    """ADMM solver for penalized GLM optimization.

    Reformulates min_w f(Xw; y) + p(w) as:
        min_{w,z} f(Xw; y) + p(z)  s.t.  w = z

    and solves via the alternating direction method of multipliers
    (scaled dual form):
        w^{k+1} = argmin_w f(Xw; y) + (rho/2)||w - z^k + u^k||^2
        z^{k+1} = prox_{p/rho}(w^{k+1} + u^k)
        u^{k+1} = u^k + w^{k+1} - z^{k+1}

    The w-update is a smooth, strongly convex problem. It is solved in
    closed form with a precomputed Cholesky factor when the loss advertises
    ``_supports_cholesky`` and ``n_features <= 2000``; otherwise (or if the
    factorisation fails) by Nesterov-accelerated gradient descent. The
    z-update reuses ``penalty.proximal()``.

    Supports numpy / cupy / torch backends via auto-detection of X.

    Parameters
    ----------
    loss : GLMLoss
    penalty : Penalty or None
        If None, the z-update is the identity (no penalty).
    X, y : arrays
    max_iter : int
        Maximum ADMM outer iterations.
    tol : float
        Convergence tolerance for primal/dual residuals.
    rho : float
        Augmented Lagrangian penalty parameter.
    adaptive_rho : bool
        Adapt rho based on primal/dual residual balance. Ignored (rho is
        pinned) when the Cholesky w-update is in use, because the
        precomputed factor depends on rho.
    cg_max_iter : int
        Maximum inner iterations for the gradient-based w-update.
    cg_tol : float
        Inner-loop convergence tolerance (per-feature mean absolute change).
    init_coef : array, optional
        Initial coefficients.
    sample_weight : array, optional

    Returns
    -------
    coef : array, n_iter : int

    Warns
    -----
    ConvergenceWarning
        If the residual criterion was not met within ``max_iter`` iterations.
    """
    backend = _resolve_backend("auto", X)
    X_proc, y_proc = loss.preprocess(X, y)
    n_features = X_proc.shape[1]

    # Initialize
    if init_coef is not None:
        w = (
            _copy_arr(init_coef)
            if hasattr(init_coef, "copy") or hasattr(init_coef, "clone")
            else np.array(init_coef).copy()
        )
    else:
        w = _zeros(n_features, backend, ref_tensor=X)

    z = _copy_arr(w)
    u = _zeros_like(w)

    if sample_weight is not None:
        _validate_uniform_sample_weight(sample_weight, X_proc.shape[0], "admm_solver")

    def _grad_w(w_vec, z_cur, u_cur):
        """Gradient of f(w) + (rho/2)||w - z_cur + u_cur||^2 w.r.t. w."""
        # ``rho`` is read from the enclosing scope, so adaptive updates apply.
        g = loss.gradient(X_proc, y_proc, w_vec, sample_weight=sample_weight)
        return g + rho * (w_vec - z_cur + u_cur)

    # Detect if loss supports Cholesky (constant Hessian, e.g. squared_error).
    # For GLM losses, use Nesterov-accelerated gradient descent.
    use_cholesky = bool(getattr(loss, '_supports_cholesky', False)) and n_features <= 2000
    _L = None
    _neg_grad_zero = None

    if use_cholesky:
        _hess_const = loss.hessian(X_proc, y_proc, w)  # XtX / n
        use_cholesky = False  # only re-enabled once the factorisation succeeds
        if hasattr(_hess_const, 'shape'):
            try:
                if backend == "numpy":
                    _L = np.linalg.cholesky(
                        _hess_const + rho * np.eye(n_features, dtype=_hess_const.dtype)
                    )
                elif backend == "cupy":
                    import cupy as cp
                    _L = cp.linalg.cholesky(
                        _hess_const + rho * cp.eye(n_features, dtype=_hess_const.dtype)
                    )
                else:
                    import torch
                    _L = torch.linalg.cholesky(
                        _hess_const
                        + rho * torch.eye(
                            n_features, dtype=_hess_const.dtype, device=_hess_const.device
                        )
                    )
                use_cholesky = True
            except (np.linalg.LinAlgError, ValueError, RuntimeError):
                # Matrix not positive-definite (numerical issues, collinear
                # features): fall back to the gradient-based w-update below.
                use_cholesky = False

    if use_cholesky:
        # Pin rho: the precomputed factor of XtX/n + rho*I would become
        # stale if rho changed between iterations.
        adaptive_rho = False
        # Precompute -grad_f(0) = Xty/n for squared_error (the constant part)
        _neg_grad_zero = -loss.gradient(
            X_proc, y_proc, _zeros_like(w), sample_weight=sample_weight
        )
    else:
        # Gradient descent step: 1/(L_f + rho). Computed on every
        # non-Cholesky path, including the fallback after a failed
        # factorisation, so the inner loop always has a step size.
        L_f = loss.lipschitz(X_proc, w, y=y_proc)
        if L_f <= 0:
            L_f = 1.0
        lr_sub = 1.0 / (L_f + rho + 1e-8)

    iteration = -1  # default if max_iter=0
    converged = False

    for iteration in range(max_iter):
        z_old = _copy_arr(z)

        # --- w-update ---
        if use_cholesky:
            # Closed-form: (XtX/n + rho*I) w = Xty/n + rho*(z - u)
            # Use precomputed Cholesky factor for forward/back substitution
            rhs = _neg_grad_zero + rho * (z - u)
            if backend == "numpy":
                from scipy.linalg import solve_triangular
                tmp = solve_triangular(_L, rhs, lower=True)
                w = solve_triangular(_L.T, tmp, lower=False)
            elif backend == "cupy":
                # Prefer the triangular solve; fall back to a general solve
                # when cupyx.scipy.linalg is unavailable.
                try:
                    from cupyx.scipy.linalg import solve_triangular
                    tmp = solve_triangular(_L, rhs, lower=True)
                    w = solve_triangular(_L.T, tmp, lower=False)
                except ImportError:
                    tmp = cp.linalg.solve(_L, rhs)
                    w = cp.linalg.solve(_L.T, tmp)
            else:
                tmp = torch.linalg.solve_triangular(_L, rhs.unsqueeze(1), upper=False)
                w = torch.linalg.solve_triangular(_L.T, tmp, upper=True).squeeze(1)
        else:
            # Nesterov-accelerated gradient descent on the w-subproblem
            w_new = _copy_arr(w)
            w_mom = _copy_arr(w)
            t_mom = 1.0
            for _ in range(cg_max_iter):
                w_old_mom = _copy_arr(w_new)
                g_sub = _grad_w(w_mom, z, u)
                w_next = w_mom - lr_sub * g_sub
                beta_mom, t_mom = _nesterov_momentum(t_mom)
                w_mom = w_next + beta_mom * (w_next - w_new)
                w_new = w_next
                diff_dev = _abs_sum_dev(w_next - w_old_mom)
                if backend != "numpy":
                    if _device_leq(diff_dev, cg_tol * n_features):
                        break
                elif diff_dev < cg_tol * n_features:
                    break
            w = w_new

        # --- z-update: proximal operator ---
        # Contract: proximal(z, step) = argmin_x step*P(x) + (1/2)||x - z||²
        # ADMM z-update needs argmin_z P(z)/rho + (1/2)||z - (w+u)||²
        #   = proximal(w + u, 1/rho) with step = 1/rho
        v = w + u
        if penalty is None:
            # No penalty: the prox of the zero function is the identity.
            z = v
        else:
            z = penalty.proximal(v, 1.0 / rho, backend=backend)

        # --- u-update: dual ascent ---
        u = u + w - z

        # --- Adaptive rho + Convergence check (batched sync) ---
        rp_dev = _norm2_dev(w - z)
        rd_dev = _norm2_dev(z - z_old)
        rp, rd_raw = _sync_scalars(rp_dev, rd_dev, backend=backend)
        r_dual = rho * rd_raw

        if adaptive_rho:
            rho_prev = rho
            if rp > 10.0 * r_dual:
                rho = min(rho * 2.0, 1e4)
            elif r_dual > 10.0 * rp:
                rho = max(rho * 0.5, 1e-4)
            if rho != rho_prev:
                # u is the *scaled* dual (u = y/rho); keep the underlying
                # dual variable y unchanged when rho changes.
                u = u * (rho_prev / rho)
                # Recompute step size to match updated rho
                lr_sub = 1.0 / (L_f + rho + 1e-8)

        if rp < tol and r_dual < tol:
            converged = True
            break

    # Return z (penalized/feasible variable), not w (unconstrained).
    # At convergence w ≈ z, but z always satisfies the penalty structure.
    n_iter = iteration + 1
    if not converged:
        warnings.warn(
            f"admm_solver did not converge within {max_iter} iterations "
            f"(loss={getattr(loss, 'name', '?')}, penalty={getattr(penalty, 'name', '?')}).",
            ConvergenceWarning,
            stacklevel=2,
        )
    return z, n_iter
@@ -0,0 +1,15 @@
1
+ """Solver convergence constants and thresholds."""
2
+
3
# Numerical thresholds shared by the solvers. Their precise use is defined at
# the call sites, which are not visible from this module.
# NOTE(review): the one-line descriptions below are inferred from the names —
# confirm against the solvers before relying on them.
_SLACK_TOLERANCE = 1e-14  # presumably a tiny slack for floating-point comparisons
_DIVERGE_COEF_NORM_CAP = 100.0  # presumably a coefficient-norm limit for divergence detection
_DIVERGE_OBJ_RATIO = 100.0  # presumably a relative objective-growth limit for divergence detection
_DIVERGE_OBJ_ABS = 10.0  # presumably an absolute objective limit for divergence detection
_BB_RESTART_DOT_TOL = 1e-14  # presumably guards the Barzilai-Borwein step against a ~0 dot product
_LIPSCHITZ_FLOOR = 1e-30  # presumably a lower bound on the Lipschitz estimate (avoids division by zero)
_LIPSCHITZ_SAFETY_LOGISTIC_CV = 2.0  # presumably a safety multiplier for the logistic CV path

# Gradient clipping thresholds (used by fista, fista_bb, fista_lla, _array_ops)
# gmax = max(coef_norm * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR, _GRAD_CLIP_MAX)
# NOTE(review): with max(), gmax can never fall below _GRAD_CLIP_MAX, which
# contradicts its description as an "absolute maximum"; the formula above may
# have been meant as min(...). Confirm against the implementation in _array_ops.
_GRAD_CLIP_COEF_FACTOR = 10.0  # scales with coefficient magnitude
_GRAD_CLIP_ABS_FLOOR = 1e3  # minimum gradient cap (prevents zero-cap at coef=0)
_GRAD_CLIP_MAX = 1e4  # absolute maximum gradient cap
@@ -0,0 +1,6 @@
1
+ """Convergence warning for solvers."""
2
+
3
+
4
class ConvergenceWarning(UserWarning):
    """Warning category raised by solvers that hit the iteration limit before converging."""