statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,104 @@
1
+ """Legacy solver methods from _solver.py.
2
+
3
+ DO NOT import in production code. Kept for reference only.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import numpy as np
9
+ from statgpu.backends import _to_numpy
10
+ from statgpu.backends._utils import _to_float_scalar
11
+ from statgpu.backends._array_ops import _abs_sum_dev, _clip_grad_on_device, _copy_arr
12
+
13
def fista_sqerr_adaptive_l1_fused(
    X, y, penalty_weights, alpha,
    XtX, Xty, yty, n_samples,
    L_init, max_iter, tol,
    backend, no_momentum=False,
):
    """Fused FISTA for squared_error + AdaptiveL1 with pre-computed XtX/Xty.

    Eliminates:
    - Redundant X@coef matmul (uses XtX instead)
    - GPU→CPU syncs (convergence check deferred)
    - Element-wise kernel overhead (fused update+proximal+momentum)

    Parameters
    ----------
    X, y : array (centered)
        Kept for signature compatibility; this function never reads them.
    penalty_weights : array (p,) — LLA weights
    alpha : float — penalty alpha
    XtX, Xty : pre-computed Gram matrix and cross-product used for the gradient
    yty : pre-computed; kept for signature compatibility, never read here
    n_samples : int — divisor applied to the gradient
    L_init : float — Lipschitz constant; the fixed step size is ``1 / L_init``
    max_iter, tol : FISTA params
    backend : 'torch' or 'cupy'
        Any value other than ``'torch'`` takes the CuPy branch.
    no_momentum : bool
        When True the momentum coefficient is held at 0.0.

    Returns
    -------
    coef : array (p,)
        Final coefficients, passed through ``_to_numpy``.
    n_iter : int
        Number of iterations executed (0 when ``max_iter`` is 0).

    Notes
    -----
    NOTE(review): ``_get_sqerr_proximal_torch`` and ``_get_sqerr_proximal_cupy``
    are neither defined nor imported in this module as shown, so calling this
    function raises ``NameError`` unless they are provided elsewhere. This
    matches the module's "kept for reference only" status.
    """
    n_features = XtX.shape[0]
    step = 1.0 / L_init

    if backend == "torch":
        import torch
        thresh = torch.tensor(
            alpha * penalty_weights * step,
            device=XtX.device, dtype=XtX.dtype,
        )
        coef = torch.zeros(n_features, device=XtX.device, dtype=XtX.dtype)
        coef_old = coef.clone()
        y_k = coef.clone()
        _fused = _get_sqerr_proximal_torch()
    else:
        import cupy as cp
        # The CuPy path works in float64 regardless of the dtype of XtX.
        thresh = cp.asarray(alpha * penalty_weights * step, dtype=cp.float64)
        coef = cp.zeros(n_features, dtype=cp.float64)
        coef_old = coef.copy()
        y_k = coef.copy()
        _fused = _get_sqerr_proximal_cupy()

    t_k = 1.0
    _sync_interval = 10  # Only check convergence every N iterations
    _clip_interval = 10  # Only clip the gradient every N iterations

    iteration = -1  # default if max_iter=0
    for iteration in range(max_iter):
        # Gradient: grad = (XtX @ y_k - Xty) / n
        grad = (XtX @ y_k - Xty) / n_samples

        # Clip gradients (avoid sync — do it on GPU)
        if iteration % _clip_interval == 0:
            grad = _clip_grad_on_device(grad, coef_old, backend)

        # Proximal gradient step (no backtracking — Lipschitz is exact for squared_error)
        # Pre-compute momentum coefficient so the fused kernel can apply it in one pass.
        if no_momentum:
            beta_mom = 0.0
        else:
            t_new = (1.0 + np.sqrt(1.0 + 4.0 * t_k * t_k)) / 2.0
            beta_mom = (t_k - 1.0) / t_new
        coef, y_k = _fused(y_k, grad, step, thresh, coef_old, beta_mom)

        # Momentum state update
        if not no_momentum:
            t_k = t_new

        # Convergence check (device-side, minimal sync): every iteration for
        # the first 20, then only every ``_sync_interval`` iterations.
        if iteration < 20 or iteration % _sync_interval == 0:
            coef_diff_dev = _abs_sum_dev(coef - coef_old)
            if _to_float_scalar(coef_diff_dev) < tol:
                break

        coef_old = _copy_arr(coef)

    return _to_numpy(coef), iteration + 1
103
+
104
+
@@ -0,0 +1,25 @@
1
+ """Penalized GLM models (split via mixin pattern)."""
2
+
3
+ from ._base import PenalizedGeneralizedLinearModel, SelectivePenalty
4
+ from ._penalized_linear import PenalizedLinearRegression
5
+ from ._penalized_logistic import PenalizedLogisticRegression
6
+ from ._penalized_poisson import PenalizedPoissonRegression
7
+ from ._penalized_gamma import PenalizedGammaRegression
8
+ from ._penalized_inverse_gaussian import PenalizedInverseGaussianRegression
9
+ from ._penalized_negative_binomial import PenalizedNegativeBinomialRegression
10
+ from ._penalized_tweedie import PenalizedTweedieRegression
11
+ from ._penalized_cv import PenalizedGLM_CV, ApproximateCVWarning
12
+
13
+ __all__ = [
14
+ "PenalizedGeneralizedLinearModel",
15
+ "SelectivePenalty",
16
+ "PenalizedLinearRegression",
17
+ "PenalizedLogisticRegression",
18
+ "PenalizedPoissonRegression",
19
+ "PenalizedGammaRegression",
20
+ "PenalizedInverseGaussianRegression",
21
+ "PenalizedNegativeBinomialRegression",
22
+ "PenalizedTweedieRegression",
23
+ "PenalizedGLM_CV",
24
+ "ApproximateCVWarning",
25
+ ]
@@ -0,0 +1,437 @@
1
+ """Core PenalizedGeneralizedLinearModel class and SelectivePenalty.
2
+
3
+ This module contains the class definition, __init__, and core utility methods.
4
+ Fit, inference, and predict methods live in separate mixin modules.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ __all__ = ["PenalizedGeneralizedLinearModel", "SelectivePenalty"]
10
+
11
+ from typing import Optional, Union, Dict, TYPE_CHECKING
12
+ import numpy as np
13
+
14
+ from statgpu._base import BaseEstimator
15
+ from statgpu._config import Device
16
+ from statgpu.cross_validation._base import INTERCEPT_CLIP_BOUND as _INTERCEPT_CLIP_BOUND
17
+ from statgpu.linear_model._gaussian_inference import validate_cov_type, validate_hac_maxlags
18
+ from statgpu.penalties._categories import NONSMOOTH as _NONSMOOTH_PENALTIES
19
+
20
+ from ._fit_mixin import _PenalizedFitMixin
21
+ from ._inference_mixin import _PenalizedInferenceMixin
22
+ from ._predict_mixin import _PenalizedPredictMixin
23
+
24
+
25
class SelectivePenalty:
    """Penalty wrapper that leaves the last intercept coefficient free.

    Created once per fit and reused across iterations. The inner penalty,
    feature count p, and backend are set via ``configure()``.

    The first ``p`` entries of a coefficient vector are treated as penalized
    features; the final entry is treated as the intercept. ``proximal`` writes
    only those ``p + 1`` positions, so vectors are expected to have length
    ``p + 1``.
    """

    def __init__(self):
        self._pen = None
        self._p = 0
        self._backend = "numpy"
        self._alpha = 0.0
        self._l1_ratio = 0.0
        # Defined here so that reading ``name`` before ``configure()`` gives
        # None instead of raising AttributeError.
        self.name = None

    def configure(self, pen, p, backend):
        """Bind the inner penalty, feature count and default backend.

        Parameters
        ----------
        pen : penalty object
            Must expose ``name``; ``alpha`` and ``l1_ratio`` are read when
            present and default to 0.0 otherwise.
        p : int
            Number of leading coefficients that are penalized.
        backend : str
            Default backend name ('numpy', 'cupy' or 'torch').
        """
        self._pen = pen
        self._p = p
        self._backend = backend
        self._alpha = float(getattr(pen, "alpha", 0.0))
        self._l1_ratio = float(getattr(pen, "l1_ratio", 0.0))
        self.name = pen.name

    def value(self, coef):
        """Return the inner penalty evaluated on the first ``p`` coefficients."""
        return self._pen.value(coef[:self._p])

    def proximal(self, w, step, backend=None):
        """Apply the inner proximal operator to the features only.

        The first ``p`` entries go through ``self._pen.proximal``. The last
        entry is only clipped to ``[-_INTERCEPT_CLIP_BOUND,
        _INTERCEPT_CLIP_BOUND]``.

        Parameters
        ----------
        w : array
            Coefficient vector; expected to have length ``p + 1``. Positions
            between ``p`` and the last entry, if any, are left uninitialized.
        step : float
            Step size forwarded to the inner proximal operator.
        backend : str, optional
            Overrides the backend given to ``configure()`` when truthy.
        """
        b = backend or self._backend
        n_feat = self._p
        result_feat = self._pen.proximal(w[:n_feat], step, backend=b)
        if b == "cupy":
            import cupy as cp
            result = cp.empty(w.shape[0], dtype=w.dtype)
            result[:n_feat] = result_feat
            result[-1] = cp.clip(w[-1], -_INTERCEPT_CLIP_BOUND, _INTERCEPT_CLIP_BOUND)
        elif b == "torch":
            import torch
            result = torch.empty(w.shape[0], dtype=w.dtype, device=w.device)
            result[:n_feat] = result_feat
            result[-1] = torch.clamp(w[-1], -_INTERCEPT_CLIP_BOUND, _INTERCEPT_CLIP_BOUND)
        else:
            result = np.empty(w.shape[0], dtype=w.dtype)
            result[:n_feat] = result_feat
            result[-1] = np.clip(w[-1], -_INTERCEPT_CLIP_BOUND, _INTERCEPT_CLIP_BOUND)
        return result

    def _smooth_alpha(self):
        """Return the coefficient of the quadratic (L2) part of the penalty.

        Raises
        ------
        ValueError
            If the inner penalty name is neither 'l2' nor 'elasticnet'.
        """
        pname = str(self._pen.name).lower()
        if pname == "l2":
            return self._alpha
        if pname == "elasticnet":
            return self._alpha * (1.0 - self._l1_ratio)
        raise ValueError("smooth solvers only support L2/ElasticNet penalties.")

    def smooth_value(self, coef):
        """Return ``0.5 * smooth_alpha * sum(coef[:p] ** 2)`` on the configured backend."""
        sa = self._smooth_alpha()
        active = coef[:self._p]
        squared = active * active
        if self._backend == "cupy":
            import cupy as cp
            return 0.5 * sa * cp.sum(squared)
        if self._backend == "torch":
            import torch
            return 0.5 * sa * torch.sum(squared)
        return 0.5 * sa * np.sum(squared)

    def smooth_gradient(self, coef):
        """Return the gradient of ``smooth_value``; entries beyond ``p`` are zero."""
        sa = self._smooth_alpha()
        if self._backend == "cupy":
            import cupy as cp
            grad = cp.zeros_like(coef)
        elif self._backend == "torch":
            import torch
            grad = torch.zeros_like(coef)
        else:
            grad = np.zeros_like(coef)
        grad[:self._p] = sa * coef[:self._p]
        return grad

    def smooth_hessian(self, coef):
        """Return smooth penalty Hessian as a dense diagonal matrix.

        The diagonal holds ``smooth_alpha`` for the first ``p`` positions and
        zero elsewhere.

        WARNING: For p > ~1000, this allocates O(p^2) memory which may cause
        OOM. Consider using the diagonal representation directly when available.
        """
        sa = self._smooth_alpha()
        size = coef.shape[0]
        if self._backend == "cupy":
            import cupy as cp
            diag = cp.zeros(size, dtype=coef.dtype)
            diag[:self._p] = sa
            return cp.diag(diag)
        if self._backend == "torch":
            import torch
            diag = torch.zeros(size, dtype=coef.dtype, device=coef.device)
            diag[:self._p] = sa
            return torch.diag(diag)
        diag = np.zeros(size, dtype=coef.dtype)
        diag[:self._p] = sa
        return np.diag(diag)
122
+
123
+
124
+
125
class PenalizedGeneralizedLinearModel(
    _PenalizedFitMixin,
    _PenalizedInferenceMixin,
    _PenalizedPredictMixin,
    BaseEstimator,
):
    """
    Penalized generalized linear model with pluggable GLM loss and penalty.

    Minimizes: loss(X, y, w) + penalty(w)

    Parameters
    ----------
    loss : str, default='squared_error'
        Loss function: 'squared_error', 'logistic', 'poisson', 'gamma',
        'negative_binomial', 'tweedie', 'inverse_gaussian'.
    penalty : str or Penalty
        Penalty type: 'l1', 'l2', 'elasticnet', 'scad', 'mcp', 'adaptive_l1',
        'group_lasso', 'group_scad', 'group_mcp', or a Penalty instance.
        The strings 'none', 'null' and '' resolve to an L2 penalty with
        alpha=0.
    solver : str, default='auto'
        Solver: 'auto', 'fista', 'fista_bb', 'irls', 'newton', 'lbfgs', 'exact'.
        'auto' selects the best path for the resolved backend and loss/penalty
        combination (see _SOLVER_DISPATCH_TABLE).
    alpha : float, default=1.0
        Regularization strength.
    l1_ratio : float, default=0.5
        Only used when penalty='elasticnet'.
    penalty_kwargs : dict, optional
        Additional arguments passed to the penalty constructor.
    fit_intercept : bool, default=True
        Whether to calculate the intercept.
    max_iter : int, default=1000
        Maximum number of iterations.
    tol : float, default=1e-4
        Tolerance for convergence.
    device : str or Device, default='auto'
        Computation device: 'cpu', 'cuda', or 'auto'.
    n_jobs : int, optional
        Forwarded to ``BaseEstimator.__init__``.
    cpu_solver : str, default='fista'
        CPU solver: 'fista', 'fista_bb', or 'coordinate_descent'.
    lipschitz_L : float, optional
        Pre-computed Lipschitz constant.
    gpu_memory_cleanup : bool, default=False
        If True, free GPU memory pool after fitting.
    compute_inference : bool, default=False
        When False, ``_validate_inference_request`` accepts any configuration.
    inference_method : str, default='debiased'
        Stored lower-cased. For L1/ElasticNet penalties the validator accepts
        values containing 'debiased', 'cpu_ols', 'gpu_ols' or 'bootstrap'.
    cov_type : str, default='nonrobust'
        Passed through ``validate_cov_type`` before being stored.
    hac_maxlags : int, optional
        Passed through ``validate_hac_maxlags`` before being stored.
    stopping : str, default='coef_delta'
        Stopping-rule name, stored lower-cased.
    lla : bool, default=True
        Stored as given and mirrored to ``_lla_enabled``.
        (presumably toggles local linear approximation — confirm in the fit mixin)
    max_lla_iters : int, default=50
        Stored as given and mirrored to ``_max_lla_iters``.
    lla_tol : float, default=1e-6
        Stored as given and mirrored to ``_lla_tol``.
    loss_kwargs : dict, optional
        Keyword arguments forwarded to ``get_glm_loss``. ``alpha`` (negative
        binomial) and ``power`` (Tweedie) are also used as fallbacks when
        building the family in ``_family_for_loss``.

    Examples
    --------
    # Lasso
    >>> model = PenalizedLinearRegression(penalty='l1', alpha=0.1)

    # Ridge
    >>> model = PenalizedLinearRegression(penalty='l2', alpha=1.0)

    # Elastic Net
    >>> model = PenalizedLinearRegression(
    ...     penalty='elasticnet', alpha=0.1, l1_ratio=0.5
    ... )
    """

    def __init__(
        self,
        loss: str = "squared_error",
        penalty: Union[str, "Penalty"] = "l1",
        alpha: float = 1.0,
        l1_ratio: float = 0.5,
        penalty_kwargs: Optional[Dict] = None,
        fit_intercept: bool = True,
        max_iter: int = 1000,
        tol: float = 1e-4,
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
        cpu_solver: str = "fista",
        solver: str = "auto",
        lipschitz_L: Optional[float] = None,
        gpu_memory_cleanup: bool = False,
        compute_inference: bool = False,
        inference_method: str = "debiased",
        cov_type: str = "nonrobust",
        hac_maxlags: Optional[int] = None,
        stopping: str = "coef_delta",
        lla: bool = True,
        max_lla_iters: int = 50,
        lla_tol: float = 1e-6,
        loss_kwargs: Optional[Dict] = None,
    ):
        super().__init__(device=device, n_jobs=n_jobs)
        self.loss = loss
        self.penalty = penalty
        self.alpha = alpha
        self.l1_ratio = l1_ratio
        self.penalty_kwargs = penalty_kwargs or {}
        self.fit_intercept = fit_intercept
        self.max_iter = max_iter
        self.tol = tol
        # Preserve original string identity for sklearn clone() compatibility
        _cpu_solver = cpu_solver.lower()
        self.cpu_solver = cpu_solver if cpu_solver == _cpu_solver else _cpu_solver
        _solver = solver.lower()
        self.solver = solver if solver == _solver else _solver
        self.lipschitz_L = lipschitz_L
        self.gpu_memory_cleanup = gpu_memory_cleanup
        self.compute_inference = compute_inference
        _inference_method = inference_method.lower()
        self.inference_method = inference_method if inference_method == _inference_method else _inference_method
        self.cov_type = validate_cov_type(cov_type)
        self.hac_maxlags = validate_hac_maxlags(hac_maxlags)
        # Preserve original object identity for sklearn clone() compatibility
        _stopping = str(stopping).lower()
        self.stopping = stopping if stopping == _stopping else _stopping
        self.lla = lla
        self.max_lla_iters = max_lla_iters
        self.lla_tol = lla_tol
        self.loss_kwargs = loss_kwargs or {}

        # Internal state
        self._penalty: Optional["Penalty"] = None
        self._lla_enabled = lla
        self._max_lla_iters = max_lla_iters
        self._lla_tol = lla_tol
        self._lla_n_iters_ = 0
        self.coef_ = None
        self.intercept_ = None
        self.n_iter_ = 0
        self._X_design = None
        self._y = None
        self._resid = None
        self._scale = None
        self._nobs = None
        self._df_resid = None
        self._params = None
        self._bse = None
        self._tvalues = None
        self._pvalues = None
        self._conf_int = None
        self._inference_result = None
        self._feature_names = None
        self._design_info = None
        self._formula_has_intercept = None
        self._selected_solver = None
        self._selected_backend_name = None
        self._init_coef = None
        self._inference_precomputed = False
        self._precomputed_gaussian_state = None
        # Simultaneous inference state
        self._conf_int_simultaneous = None
        self._simultaneous_enabled = False
        self._debiased_M_cpu = None
        self._use_intercept = None  # formula-derived override; None = use fit_intercept

    @property
    def _effective_intercept(self):
        """Return effective intercept flag. Formula path overrides via _use_intercept."""
        if self._use_intercept is not None:
            return self._use_intercept
        return self.fit_intercept

    def _resolve_penalty(self) -> "Penalty":
        """Resolve penalty string or instance to a Penalty object.

        A ``Penalty`` instance is returned unchanged. Otherwise the name is
        lower-cased and stripped, and the penalty is built with ``alpha`` plus
        ``penalty_kwargs`` (``alpha`` wins on a key clash); ``l1_ratio`` is
        added for 'elasticnet' / 'en'.
        """
        # Lazy import to avoid circular dependency
        from statgpu.penalties import get_penalty, Penalty

        if isinstance(self.penalty, Penalty):
            return self.penalty

        # Map "none"/"null" to l2 with alpha=0 (no regularization)
        pen_name = str(self.penalty).lower().strip()
        if pen_name in ("none", "null", ""):
            return get_penalty("l2", alpha=0.0)

        kwargs = {**self.penalty_kwargs, "alpha": self.alpha}
        if pen_name in ("elasticnet", "en"):
            kwargs["l1_ratio"] = self.l1_ratio

        return get_penalty(pen_name, **kwargs)

    def _resolve_loss(self):
        """Resolve loss string to a GLMLoss object."""
        from statgpu.glm_core import get_glm_loss

        return get_glm_loss(self.loss, **self.loss_kwargs)

    def _validate_solver_penalty(self):
        """Validate solver/penalty combinations before backend dispatch.

        Raises
        ------
        ValueError
            If solver='exact' is used with anything other than squared-error
            + L2, solver='irls' with a non-L2 penalty, or solver='newton' /
            'lbfgs' with a penalty listed as non-smooth.
        """
        solver_name = self.solver
        penalty_name = str(getattr(self._penalty, "name", self.penalty)).lower()
        non_smooth = _NONSMOOTH_PENALTIES
        if solver_name == "exact":
            if self.loss != "squared_error" or penalty_name != "l2":
                raise ValueError(
                    "solver='exact' is only supported for squared-error L2/Ridge models."
                )
            return
        if solver_name == "irls" and penalty_name != "l2":
            raise ValueError(
                "solver='irls' only supports smooth L2 penalized GLM objectives."
            )
        if solver_name in ("newton", "lbfgs") and penalty_name in non_smooth:
            raise ValueError(
                f"solver='{solver_name}' only supports smooth objectives; "
                f"use solver='fista' for penalty='{penalty_name}'."
            )

    def _validate_inference_request(self):
        """Reject unsupported penalized inference paths with a clear error.

        Currently supported:
        - squared_error + L2 (standard OLS inference)
        - squared_error + L1/ElasticNet (debiased Lasso, cpu_ols_inference, bootstrap)

        Raises
        ------
        NotImplementedError
            If ``compute_inference`` is True and the configuration matches
            none of the accepted combinations.
        """
        if not self.compute_inference:
            return
        penalty_name = str(getattr(self._penalty, "name", self.penalty)).lower()
        if self.loss == "squared_error" and penalty_name == "l2":
            return
        inference_method = str(getattr(self, "inference_method", "debiased")).lower()
        # NOTE(review): this branch does not check ``self.loss``, so L1 /
        # ElasticNet with a non-squared-error loss is accepted here even
        # though the docstring lists squared_error only — confirm intent.
        if penalty_name in ("l1", "elasticnet", "en"):
            if "debiased" in inference_method:
                return
            if "cpu_ols" in inference_method or "gpu_ols" in inference_method:
                return
            if "bootstrap" in inference_method:
                return
        raise NotImplementedError(
            f"compute_inference=True with penalty='{penalty_name}' and "
            f"loss='{self.loss}' is not supported. Use inference_method='debiased', "
            f"'cpu_ols_inference', or 'bootstrap' for L1/ElasticNet, "
            f"or compute_inference=False to skip inference."
        )

    def _clear_inference_state(self):
        """Reset the stored design/response and per-coefficient inference results to None."""
        self._X_design = None
        self._y = None
        self._resid = None
        self._scale = None
        self._nobs = None
        self._df_resid = None
        self._params = None
        self._bse = None
        self._tvalues = None
        self._pvalues = None
        self._conf_int = None
        self._inference_result = None

    def _family_for_loss(self):
        """Return the GLM family object matching ``self.loss``.

        The family is cached on the instance (avoids re-creating it on every
        predict/score). The cache is keyed on the loss name and, for
        negative-binomial / Tweedie, on the resolved ``alpha`` / ``power``, so
        changing any of them after the first call rebuilds the family instead
        of returning a stale one. Unrecognized loss names map to ``Gaussian``.
        """
        loss_name = self.loss
        # Resolve the family parameter first: it is part of the cache key.
        # The fitted loss object (``_loss``) wins over ``loss_kwargs``.
        if loss_name == "negative_binomial":
            family_param = getattr(
                getattr(self, "_loss", None),
                "alpha",
                getattr(self, "loss_kwargs", {}).get("alpha", 1.0),
            )
        elif loss_name == "tweedie":
            family_param = getattr(
                getattr(self, "_loss", None),
                "power",
                getattr(self, "loss_kwargs", {}).get("power", 1.5),
            )
        else:
            family_param = None
        cache_key = (loss_name, family_param)

        cached = getattr(self, '_family_cache', None)
        if cached is not None:
            missing = object()
            recorded_key = getattr(self, "_family_cache_key", missing)
            if recorded_key is missing:
                # Cache populated without a key (e.g. seeded externally):
                # trust it, as the unkeyed implementation did.
                return cached
            try:
                if recorded_key == cache_key:
                    return cached
            except (TypeError, ValueError):
                # Non-scalar parameters may not compare cleanly; rebuild.
                pass

        from statgpu.glm_core._family import (
            Binomial,
            Gaussian,
            Poisson,
            Gamma,
            InverseGaussian,
            NegativeBinomial,
            Tweedie,
        )

        if loss_name == "logistic":
            fam = Binomial()
        elif loss_name == "poisson":
            fam = Poisson()
        elif loss_name == "gamma":
            fam = Gamma()
        elif loss_name == "inverse_gaussian":
            fam = InverseGaussian()
        elif loss_name == "negative_binomial":
            fam = NegativeBinomial(alpha=family_param)
        elif loss_name == "tweedie":
            fam = Tweedie(power=family_param)
        else:
            fam = Gaussian()

        self._family_cache = fam
        self._family_cache_key = cache_key
        return fam

    def _column_stack(self, arrays, backend_name):
        """Column-stack ``arrays`` with the library matching ``backend_name`` (NumPy by default)."""
        if backend_name == "cupy":
            import cupy as cp
            return cp.column_stack(arrays)
        if backend_name == "torch":
            import torch
            return torch.column_stack(arrays)
        return np.column_stack(arrays)

    def _ones(self, n, backend_name, ref):
        """Return a length-``n`` vector of ones matching ``ref``'s dtype.

        On the torch backend the vector is also placed on ``ref.device``. On
        the NumPy path the dtype falls back to float64 when ``ref`` has no
        ``dtype`` attribute.
        """
        if backend_name == "cupy":
            import cupy as cp
            return cp.ones(n, dtype=ref.dtype)
        if backend_name == "torch":
            import torch
            return torch.ones(n, dtype=ref.dtype, device=ref.device)
        return np.ones(n, dtype=getattr(ref, "dtype", np.float64))

    def _selective_penalty(self, p, backend_name):
        """Penalty wrapper that leaves the last intercept coefficient free.

        Creates a fresh instance per call to avoid thread-local singleton
        conflicts in nested CV within the same thread.
        """
        sp = SelectivePenalty()
        sp.configure(self._penalty, p, backend_name)
        return sp