statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,300 @@
1
+ """
2
+ Stepwise model selection for regression models.
3
+ Supports forward, backward, and bidirectional selection.
4
+ """
5
+
6
+ from typing import Optional, Union, List, Literal
7
+ import numpy as np
8
+ from copy import deepcopy
9
+ from joblib import Parallel, delayed
10
+
11
+ from statgpu.linear_model import LinearRegression, Ridge, Lasso, LogisticRegression
12
+
13
+
14
class StepwiseSelector:
    """
    Stepwise model selection using AIC or BIC criterion.

    Supports forward selection, backward elimination, and bidirectional search.
    At every step the single add/remove move that lowers the criterion the
    most is applied; the search stops when no move improves it.

    Parameters
    ----------
    model_class : class
        Model class to use (LinearRegression, Ridge, Lasso, LogisticRegression).
        It is instantiated as ``model_class(**model_kwargs)`` and must provide
        ``fit(X, y)``. Scores are read from the fitted model's ``aic``/``bic``
        attributes, falling back to an approximation based on ``rsquared``.
    criterion : str, default='aic'
        Criterion for model selection: 'aic' or 'bic'.
    direction : str, default='both'
        Direction of search: 'forward', 'backward', or 'both'.
    max_features : int, optional
        Maximum number of features to select. This caps *additions*;
        ``None`` means no cap. Backward elimination is not forced down to
        this size, it only removes features while the criterion improves.
    n_jobs : int, optional
        Number of joblib workers used to score candidates. ``None`` or ``1``
        scores candidates serially.
    **model_kwargs
        Additional arguments passed to the model.

    Attributes
    ----------
    selected_features_ : list
        Sorted indices of selected features (``None`` before ``fit``).
    best_model_ : object
        Model fitted on ``X[:, selected_features_]``; ``None`` when nothing
        was selected.
    aic_history_ : list
        AIC values at each step.
    bic_history_ : list
        BIC values at each step.
    """

    _DIRECTIONS = ('forward', 'backward', 'both')

    def __init__(
        self,
        model_class,
        criterion: str = 'aic',
        direction: Literal['forward', 'backward', 'both'] = 'both',
        max_features: Optional[int] = None,
        n_jobs: Optional[int] = None,
        **model_kwargs
    ):
        self.model_class = model_class
        self.criterion = criterion.lower()
        self.direction = direction
        self.max_features = max_features
        self.n_jobs = n_jobs
        self.model_kwargs = model_kwargs

        if self.criterion not in ('aic', 'bic'):
            raise ValueError("criterion must be 'aic' or 'bic'")
        # An unknown direction used to fall through and silently select nothing.
        if self.direction not in self._DIRECTIONS:
            raise ValueError("direction must be 'forward', 'backward', or 'both'")

        self.selected_features_ = None
        self.best_model_ = None
        self.aic_history_ = []
        self.bic_history_ = []
        self._score_cache = {}

    def fit(self, X, y):
        """
        Fit stepwise model selection.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples,)
            Target values.

        Returns
        -------
        self : object
        """
        X = np.asarray(X)
        y = np.asarray(y)
        _, n_features = X.shape

        # Resolve the cap locally so the user-supplied hyper-parameter is not
        # overwritten and a later fit on wider data is not limited by this one.
        max_features = n_features if self.max_features is None else self.max_features

        # Reset all fitted state so repeated fits do not accumulate history
        # or keep a stale model around.
        self._score_cache = {}
        self.aic_history_ = []
        self.bic_history_ = []
        self.selected_features_ = None
        self.best_model_ = None

        if self.direction == 'backward':
            selected = list(range(n_features))
            remaining = []
        else:  # 'forward' and 'both' start from the empty model
            selected = []
            remaining = list(range(n_features))

        # Score the starting model.
        best_score = self._fit_and_score(X, y, selected)
        self.aic_history_.append(best_score['aic'])
        self.bic_history_.append(best_score['bic'])

        try_add = self.direction in ('forward', 'both')
        try_remove = self.direction in ('backward', 'both')
        iteration = 0

        while True:
            iteration += 1
            best_move = None  # (action, feature) of the best improving move

            # The size cap gates additions only. It must not gate the whole
            # loop, otherwise backward elimination (which starts at the cap)
            # would never execute a single step.
            if try_add and len(selected) < max_features:
                candidates = [(feature, selected + [feature]) for feature in remaining]
                for feature, score in self._evaluate_candidates(X, y, candidates):
                    if score[self.criterion] < best_score[self.criterion]:
                        best_score = score
                        best_move = ('add', feature)

            if try_remove and selected:
                candidates = [
                    (feature, [f for f in selected if f != feature])
                    for feature in selected
                ]
                for feature, score in self._evaluate_candidates(X, y, candidates):
                    if score[self.criterion] < best_score[self.criterion]:
                        best_score = score
                        best_move = ('remove', feature)

            if best_move is None:
                break

            best_action, best_feature = best_move
            if best_action == 'add':
                selected.append(best_feature)
                remaining.remove(best_feature)
            else:
                selected.remove(best_feature)
                remaining.append(best_feature)

            self.aic_history_.append(best_score['aic'])
            self.bic_history_.append(best_score['bic'])

            print(f"Step {iteration}: {best_action} feature {best_feature}, "
                  f"{self.criterion.upper()}={best_score[self.criterion]:.2f}")

        # Fit the final model on the columns in the SAME (sorted) order that
        # predict() and score() use, so coefficients line up with columns.
        self.selected_features_ = sorted(selected)
        if self.selected_features_:
            self.best_model_ = self.model_class(**self.model_kwargs)
            self.best_model_.fit(X[:, self.selected_features_], y)

        return self

    def _evaluate_candidates(self, X, y, candidates):
        """Score ``(feature, feature_indices)`` candidates, memoizing results.

        Returns a list of ``(feature, score)`` pairs in candidate order.
        Uses joblib when ``n_jobs`` requests more than one worker and there is
        more than one candidate; otherwise runs serially.
        """
        candidates = list(candidates)
        if self.n_jobs is None or self.n_jobs == 1 or len(candidates) <= 1:
            return [
                (feature, self._fit_and_score(X, y, feature_indices))
                for feature, feature_indices in candidates
            ]

        keys = [tuple(sorted(feature_indices)) for _, feature_indices in candidates]

        # Dispatch only subsets that have not been scored yet.
        pending = {}
        for key, (_, feature_indices) in zip(keys, candidates):
            if key not in self._score_cache:
                pending.setdefault(key, feature_indices)

        if pending:
            scores = Parallel(n_jobs=self.n_jobs)(
                delayed(self._fit_and_score)(X, y, feature_indices)
                for feature_indices in pending.values()
            )
            # Workers may run in other processes, so their cache writes are
            # not visible here; record the results in this process.
            self._score_cache.update(zip(pending.keys(), scores))

        return [
            (feature, self._score_cache[key])
            for (feature, _), key in zip(candidates, keys)
        ]

    def _fit_and_score(self, X, y, feature_indices):
        """Fit model and return AIC/BIC scores (memoized per feature subset)."""
        key = tuple(sorted(feature_indices))
        cached = self._score_cache.get(key)
        if cached is not None:
            return cached

        score = self._compute_score(X, y, feature_indices)
        self._score_cache[key] = score
        return score

    def _compute_score(self, X, y, feature_indices):
        """Fit one model on ``feature_indices`` and build its score dict.

        Subsets that cannot be scored (empty subset, model without ``aic`` or
        ``rsquared``, or a model that raises) get ``inf`` for both criteria so
        they are never preferred.
        """
        if len(feature_indices) == 0:
            # Null model: no generic way to fit it, treat as worst possible.
            return {'aic': np.inf, 'bic': np.inf}

        model = self.model_class(**self.model_kwargs)
        try:
            model.fit(X[:, feature_indices], y)

            if getattr(model, 'aic', None) is not None:
                return {'aic': model.aic, 'bic': model.bic}

            if hasattr(model, 'rsquared'):
                # Fallback: R²-based approximation of AIC/BIC.
                n = len(y)
                k = len(feature_indices) + 1  # +1 for intercept
                r2 = model.rsquared
                log_term = n * np.log(1 - r2 + 1e-10)
                return {'aic': log_term + 2 * k, 'bic': log_term + k * np.log(n)}
        except Exception as exc:
            # Best effort: a failing subset is skipped rather than aborting
            # the search, but the failure is surfaced instead of hidden.
            import warnings

            warnings.warn(
                f"Scoring features {list(feature_indices)} failed: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )

        return {'aic': np.inf, 'bic': np.inf}

    def predict(self, X):
        """Predict using the best model."""
        if self.best_model_ is None:
            raise RuntimeError("Model has not been fitted yet.")
        X = np.asarray(X)
        return self.best_model_.predict(X[:, self.selected_features_])

    def score(self, X, y):
        """Return R² score of the best model."""
        if self.best_model_ is None:
            raise RuntimeError("Model has not been fitted yet.")
        X = np.asarray(X)
        return self.best_model_.score(X[:, self.selected_features_], y)

    def summary(self):
        """Print summary of stepwise selection.

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        if self.selected_features_ is None:
            raise RuntimeError("Model has not been fitted yet.")
        print("=" * 60)
        print("Stepwise Model Selection Summary")
        print("=" * 60)
        print(f"Criterion: {self.criterion.upper()}")
        print(f"Direction: {self.direction}")
        print(f"Selected features: {self.selected_features_}")
        print(f"Number of features: {len(self.selected_features_)}")
        if self.aic_history_:
            print(f"Final AIC: {self.aic_history_[-1]:.2f}")
            print(f"Final BIC: {self.bic_history_[-1]:.2f}")
        print("=" * 60)
263
+
264
+
265
def stepwise_selection(
    X, y,
    model_class=LinearRegression,
    criterion: str = 'aic',
    direction: str = 'both',
    **model_kwargs
):
    """
    Build a ``StepwiseSelector``, fit it on ``(X, y)`` and return it.

    Parameters
    ----------
    X, y : array-like
        Training data and targets, handed to ``StepwiseSelector.fit``.
    model_class : class
        Model class the selector instantiates for every candidate subset.
    criterion : str, default='aic'
        Selection criterion.
    direction : str, default='both'
        Search direction.
    **model_kwargs
        Forwarded as keyword arguments to the ``StepwiseSelector``
        constructor. Keywords matching the constructor's own parameters
        (``max_features``, ``n_jobs``) configure the selector; the rest
        become model parameters.

    Returns
    -------
    selector : StepwiseSelector
        Fitted selector.
    """
    options = dict(
        model_kwargs,
        model_class=model_class,
        criterion=criterion,
        direction=direction,
    )
    fitted = StepwiseSelector(**options)
    fitted.fit(X, y)
    return fitted
@@ -0,0 +1,81 @@
1
+ """
2
+ GLM core utilities for statgpu.
3
+
4
+ Usage:
5
+ from statgpu.glm_core import get_glm_loss, register_glm_loss
6
+
7
+ # Built-in
8
+ loss = get_glm_loss('squared_error')
9
+
10
+ # Custom
11
+ @register_glm_loss('huber')
12
+ class HuberLoss(GLMLoss):
13
+ ...
14
+ """
15
+
16
+ from ._base import (
17
+ GLMLoss,
18
+ get_glm_loss,
19
+ register_glm_loss,
20
+ list_glm_losses,
21
+ )
22
+ from ._squared import SquaredErrorLoss
23
+ from ._logistic import LogisticLoss
24
+ from ._poisson import PoissonLoss
25
+ from ._gamma import GammaLoss
26
+ from ._inverse_gaussian import InverseGaussianLoss
27
+ from ._negative_binomial import NegativeBinomialLoss
28
+ from ._tweedie import TweedieLoss
29
+ from ._family import (
30
+ GLMFamily,
31
+ Link,
32
+ Gaussian,
33
+ Binomial,
34
+ Poisson,
35
+ Gamma,
36
+ InverseGaussian,
37
+ NegativeBinomial,
38
+ Tweedie,
39
+ )
40
+ from ._irls import IRLSSolver
41
+
42
+ # Solvers: re-export from solvers/ (generic)
43
+ from statgpu.solvers import (
44
+ fista_solver,
45
+ fista_bb_solver,
46
+ fista_lla_path,
47
+ newton_solver,
48
+ lbfgs_solver,
49
+ admm_solver,
50
+ ConvergenceWarning,
51
+ )
52
+
53
# Public API of statgpu.glm_core. Every name imported above for re-export is
# listed here, including ``fista_lla_path`` (it is imported from
# statgpu.solvers together with the other solvers but was missing from the
# list, so ``from statgpu.glm_core import *`` did not expose it).
__all__ = [
    "GLMLoss",
    "SquaredErrorLoss",
    "LogisticLoss",
    "PoissonLoss",
    "GammaLoss",
    "InverseGaussianLoss",
    "NegativeBinomialLoss",
    "TweedieLoss",
    "GLMFamily",
    "Link",
    "Gaussian",
    "Binomial",
    "Poisson",
    "Gamma",
    "InverseGaussian",
    "NegativeBinomial",
    "Tweedie",
    "IRLSSolver",
    "fista_solver",
    "fista_bb_solver",
    "fista_lla_path",
    "admm_solver",
    "newton_solver",
    "lbfgs_solver",
    "ConvergenceWarning",
    "get_glm_loss",
    "register_glm_loss",
    "list_glm_losses",
]
@@ -0,0 +1,202 @@
1
+ """
2
+ Base class for GLM loss functions in statgpu.
3
+
4
+ The GLM core loss framework supports 7 families:
5
+ - Squared error (linear regression)
6
+ - Logistic loss (binary classification)
7
+ - Poisson loss (count data)
8
+ - Gamma loss (positive continuous)
9
+ - Inverse Gaussian loss (positive continuous)
10
+ - Negative Binomial loss (overdispersed count data)
11
+ - Tweedie loss (generalized GLM family)
12
+
13
+ Structured models such as Cox, panel, and time-series models should use a
14
+ future objective layer rather than this GLM-specific interface.
15
+ """
16
+
17
+ __all__ = ["GLMLoss", "get_glm_loss", "register_glm_loss", "list_glm_losses"]
18
+
19
+
20
+ from abc import ABC, abstractmethod
21
+ from typing import Optional
22
+
23
+ import numpy as np
24
+ from statgpu.backends._array_ops import _xp as _get_xp_mod
25
+ from statgpu.backends._utils import _to_float_scalar
26
+
27
+
28
class GLMLoss(ABC):
    """Base class for GLM loss functions.

    The optimisation target is ``loss(X, y, w) + penalty(w)``.

    A subclass supplies the per-sample formulas; the aggregate methods
    ``value()``, ``gradient()`` and ``fused_value_and_gradient()`` are built
    on top of them here, so each formula is written only once.

    Methods a subclass implements:
    - ``per_sample_value(eta, y)`` — per-sample loss ℓ(η, y)
    - ``per_sample_gradient(eta, y)`` — per-sample gradient ∂ℓ/∂η
    - ``_mu_from_eta(eta)`` — link inverse μ = g⁻¹(η), with clipping
    """

    name: str = "base"
    y_type: str = "continuous"
    smooth_gradient: bool = True
    has_hessian: bool = False

    # ── Optimization hints (read by solvers, overridable per subclass) ──
    # Lipschitz safety factor
    _lipschitz_safety: float = 1.0
    # Extra safety factor in CV mode
    _lipschitz_safety_cv: float = 1.0
    # Whether Lipschitz needs y-scaling
    _lipschitz_uses_y: bool = False
    # Nesterov momentum cap (None=unlimited)
    _momentum_beta_cap: Optional[float] = None
    # Disable momentum entirely
    _skip_momentum: bool = False
    # Hessian is constant (Newton fast path)
    _has_constant_hessian: bool = False
    # Prefer FISTA over FISTA-BB for smooth penalties
    _prefer_fista_over_bb: bool = False
    # True for squared_error (XtX constant, no y-scaling)
    _is_quadratic: bool = False
    # True for squared_error (ADMM can use Cholesky)
    _supports_cholesky: bool = False
    # True for logistic (async GPU loop not suitable)
    _gpu_loop_excluded: bool = False
    # Cap momentum when penalty is non-smooth
    _conservative_momentum_with_nonsmooth: bool = False
    # True for inverse Gaussian (special BB handling)
    _inverse_gaussian: bool = False
    # True for Tweedie (special BB handling)
    _tweedie: bool = False
    # True for Poisson (conservative momentum burn-in)
    _poisson_like: bool = False
    # True for Gamma (adjusted BB/momentum params)
    _gamma_like: bool = False

    # ── Per-sample formulas (the single source of truth) ──────────────

    def per_sample_value(self, eta, y):
        """Return the per-sample loss ℓ(η, y) as an array of shape (n,)."""
        raise NotImplementedError(f"{self.name} does not implement per_sample_value")

    def per_sample_gradient(self, eta, y):
        """Return the per-sample gradient ∂ℓ/∂η as an array of shape (n,)."""
        raise NotImplementedError(f"{self.name} does not implement per_sample_gradient")

    def _mu_from_eta(self, eta):
        """Apply the inverse link μ = g⁻¹(η); override to add clipping."""
        # Identity link by default.
        return eta

    # ── Aggregates derived from the per-sample formulas ───────────────

    def value(self, X, y, coef, sample_weight=None) -> float:
        """Mean loss (1/n) Σ ℓ(ηᵢ, yᵢ); a weighted mean when weights are given."""
        xp = _get_xp_mod(X)
        eta = X @ coef
        ps = self.per_sample_value(eta, y)
        if sample_weight is None:
            return float(xp.sum(ps)) / X.shape[0]
        return float(xp.dot(sample_weight, ps)) / float(sample_weight.sum())

    def gradient(self, X, y, coef, sample_weight=None) -> np.ndarray:
        """Gradient w.r.t. ``coef``: X' ∂ℓ/∂η / n (weight-normalised if weighted)."""
        # The array module is resolved as in the sibling methods; the result
        # is not used by the arithmetic below.
        xp = _get_xp_mod(X)
        eta = X @ coef
        resid = self.per_sample_gradient(eta, y)
        if sample_weight is None:
            return X.T @ resid / X.shape[0]
        return X.T @ (sample_weight * resid) / float(sample_weight.sum())

    def fused_value_and_gradient(self, X, y, coef, sample_weight=None):
        """Return ``(value, gradient)`` while computing ``X @ coef`` only once."""
        xp = _get_xp_mod(X)
        eta = X @ coef
        ps = self.per_sample_value(eta, y)
        resid = self.per_sample_gradient(eta, y)
        if sample_weight is None:
            n = X.shape[0]
            return float(xp.sum(ps)) / n, X.T @ resid / n
        sw_sum = float(sample_weight.sum())
        val = float(xp.dot(sample_weight, ps)) / sw_sum
        grad = X.T @ (sample_weight * resid) / sw_sum
        return val, grad

    def hessian(self, X, y, coef) -> np.ndarray:
        """Hessian matrix (for IRLS/Newton).

        This base implementation always raises ``NotImplementedError``.
        NOTE(review): the original documentation ties this to the auto
        solver falling back to FISTA — confirm against the solver code.
        """
        raise NotImplementedError(
            f"{self.name} does not support Hessian."
        )

    def lipschitz(self, X, coef, y=None) -> float:
        """Lipschitz constant (for FISTA step size step=1/L).

        Computed as ``_max_eigval_power(X.T @ X) / n``; ``coef`` and ``y`` are
        accepted but not used by this default implementation.
        """
        from statgpu.backends._array_ops import _max_eigval_power
        gram = X.T @ X
        return _max_eigval_power(gram) / X.shape[0]

    def preprocess(self, X, y):
        """Hook for transforming the inputs; the default returns them unchanged."""
        return X, y

    def predict(self, X, coef):
        """Map the linear predictor to a prediction; the default is ``X @ coef``."""
        return X @ coef
140
+
141
+
142
+ # ─── Registry ──────────────────────────────────────────────────────────────
143
+
144
+ _GLM_LOSS_REGISTRY: dict = {}
145
+
146
+
147
def get_glm_loss(name: str, **kwargs) -> GLMLoss:
    """
    Look up a registered GLM loss class and instantiate it.

    Parameters
    ----------
    name : str
        GLM loss name: 'squared_error', 'logistic', 'poisson', etc.
    **kwargs
        Keyword arguments forwarded to the loss constructor.

    Returns
    -------
    GLMLoss
        A new instance of the class registered under ``name``.

    Raises
    ------
    ValueError
        If ``name`` is not a key of the registry; the message lists the
        names that are registered.
    """
    if name in _GLM_LOSS_REGISTRY:
        loss_cls = _GLM_LOSS_REGISTRY[name]
        return loss_cls(**kwargs)

    available = list(_GLM_LOSS_REGISTRY)
    raise ValueError(
        f"Unknown GLM loss: {name}. Available GLM losses: {available}"
    )
174
+
175
+
176
def register_glm_loss(name: str):
    """
    Build a class decorator that stores a GLM loss class in the registry.

    Parameters
    ----------
    name : str
        Registry key for the decorated class. An existing entry under the
        same key is replaced.

    Returns
    -------
    callable
        Decorator that registers the class it receives and returns that
        class unchanged. It raises ``TypeError`` when the class does not
        derive from ``GLMLoss``.
    """
    def _register(cls):
        if issubclass(cls, GLMLoss):
            _GLM_LOSS_REGISTRY[name] = cls
            return cls
        raise TypeError(
            f"GLM loss class must inherit from GLMLoss, got {cls.__bases__}"
        )

    return _register
198
+
199
+
200
def list_glm_losses() -> list:
    """Return the names currently present in the GLM loss registry, as a new list."""
    return [registered_name for registered_name in _GLM_LOSS_REGISTRY]