statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168):
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,936 @@
1
+ """
2
+ Elastic Net regression with GPU acceleration and full statistical inference.
3
+
4
+ Elastic Net combines L1 and L2 regularization:
5
+ minimize (1/(2n)) * ||y - Xw||²₂ + α * l1_ratio * ||w||₁ + 0.5 * α * (1 - l1_ratio) * ||w||²₂
6
+
7
+ where:
8
+ - α (alpha) controls the overall regularization strength
9
+ - l1_ratio controls the mix: 1.0 = Lasso, 0.0 = Ridge, 0.5 = balanced Elastic Net
10
+
11
+ Optimized implementations:
12
+ - CPU: FISTA with pre-computed Gram matrix
13
+ - GPU (CuPy): Fused kernel operations with @cp.fuse()
14
+ - GPU (Torch): torch.compile() with warm-up strategy
15
+ """
16
+
17
+ from typing import Optional, Union
18
+ import warnings
19
+ import numpy as np
20
+
21
+ from statgpu._base import BaseEstimator
22
+ from statgpu._config import Device
23
+
24
+
25
+ # ============================================================================
26
+ # CuPy Fused Kernels for Elastic Net
27
+ # ============================================================================
28
+
29
+ def _get_cupy_fused_kernels():
30
+ """Lazy load CuPy fused kernels."""
31
+ try:
32
+ import cupy as cp
33
+ except ImportError:
34
+ return None, None, None, None
35
+
36
+ @cp.fuse()
37
+ def _elastic_net_proximal(x, thresh, l2_scale):
38
+ """Fused soft thresholding with L2 scaling."""
39
+ return cp.sign(x) * cp.maximum(cp.abs(x) - thresh, 0) / l2_scale
40
+
41
+ @cp.fuse()
42
+ def _fista_momentum_update(coef, coef_old, t_old, t_new):
43
+ """Fused FISTA momentum update."""
44
+ beta = (t_old - 1) / t_new
45
+ return coef + beta * (coef - coef_old)
46
+
47
+ @cp.fuse()
48
+ def _compute_coef_delta(coef, coef_old):
49
+ """Compute absolute coefficient change."""
50
+ return cp.abs(coef - coef_old)
51
+
52
+ ELASTIC_NET_PROXIMAL_KERNEL = cp.ElementwiseKernel(
53
+ 'float64 w_tilde, float64 thresh, float64 l2_scale',
54
+ 'float64 coef',
55
+ '''
56
+ double abs_w = abs(w_tilde);
57
+ if (abs_w > thresh) {
58
+ coef = (w_tilde > 0 ? 1.0 : -1.0) * (abs_w - thresh) / l2_scale;
59
+ } else {
60
+ coef = 0.0;
61
+ }
62
+ ''',
63
+ 'elastic_net_proximal'
64
+ )
65
+
66
+ return _elastic_net_proximal, _fista_momentum_update, _compute_coef_delta, ELASTIC_NET_PROXIMAL_KERNEL
67
+
68
+
69
def _fit_elasticnet_cupy_optimized(X, y, alpha, l1_ratio, n_samples, n_features,
                                   max_iter=1000, tol=1e-4, lipschitz_L=None,
                                   stopping='coef_delta', warmup=True):
    """
    Fit Elastic Net with FISTA on CuPy arrays using a fused proximal kernel.

    ``X`` and ``y`` are used exactly as given: no centering, intercept or
    sample weighting is applied inside this function.

    Parameters
    ----------
    X : cupy.ndarray of shape (n_samples, n_features)
        Design matrix.
    y : cupy.ndarray of shape (n_samples,)
        Target vector.
    alpha : float
        Overall regularization strength.
    l1_ratio : float
        L1 share of the penalty; the L2 share is ``1 - l1_ratio``.
    n_samples, n_features : int
        Dimensions used for scaling the gradient and sizing the buffers.
    max_iter : int, default=1000
        Maximum number of FISTA iterations.
    tol : float, default=1e-4
        Stop as soon as the monitored violation drops below this value.
    lipschitz_L : float, optional
        Lipschitz constant of the smooth part. When omitted it is computed
        as ``lambda_max(X.T @ X) / n_samples``.
    stopping : {'coef_delta', 'kkt'}, default='coef_delta'
        ``'kkt'`` monitors the largest KKT violation; any other value
        monitors the largest absolute coefficient change.
    warmup : bool, default=True
        Call the fused kernel once before the loop to trigger compilation.

    Returns
    -------
    coef : cupy.ndarray of shape (n_features,)
        Estimated coefficients.
    n_iter : int
        Number of iterations performed (0 in the degenerate ``L <= 0`` case).

    Raises
    ------
    ImportError
        If the CuPy kernels cannot be created.
    """
    import math

    import cupy as cp

    # Get fused kernels
    elastic_net_proximal, _, _, _ = _get_cupy_fused_kernels()
    if elastic_net_proximal is None:
        raise ImportError("CuPy not available")

    # Precompute Gram matrix and cross product
    XtX = X.T @ X
    Xty = X.T @ y

    l2_ratio = 1.0 - l1_ratio
    l1_weight = alpha * l1_ratio
    l2_weight = alpha * l2_ratio

    # Lipschitz constant: L = lambda_max(XtX) / n
    if lipschitz_L is not None:
        L = float(lipschitz_L)
    else:
        # eigvalsh returns eigenvalues in ascending order
        L = float(cp.linalg.eigvalsh(XtX)[-1]) / n_samples

    if L <= 0:
        # Same dtype as the regular path so callers get a consistent result.
        return cp.zeros(n_features, dtype=X.dtype), 0

    step = 1.0 / L
    thresh = l1_weight * step
    l2_scale = 1.0 + l2_weight * step
    inv_n_samples = 1.0 / n_samples

    coef = cp.zeros(n_features, dtype=X.dtype)
    y_k = cp.zeros(n_features, dtype=X.dtype)
    coef_old = cp.zeros(n_features, dtype=X.dtype)

    # FISTA state (kept as host-side Python floats)
    t_k = 1.0
    n_iter = 0

    # Warm-up on an initialised buffer; the result is discarded.
    if warmup:
        _ = elastic_net_proximal(y_k, thresh, l2_scale)

    for iteration in range(max_iter):
        # Store old coefficients for convergence check
        coef_old[:] = coef

        # Gradient of the RSS term only: (XtX @ y_k - Xty) / n
        grad = XtX @ y_k
        grad -= Xty
        grad *= inv_n_samples

        # Proximal step (L1 shrinkage + L2 scaling handled in the kernel)
        w_tilde = y_k - step * grad
        coef = elastic_net_proximal(w_tilde, thresh, l2_scale)

        # FISTA momentum update; math.sqrt keeps the scalar on the host
        # instead of creating 0-d device arrays every iteration.
        t_new = (1.0 + math.sqrt(1.0 + 4.0 * t_k * t_k)) * 0.5
        beta = (t_k - 1.0) / t_new
        y_k = coef + beta * (coef - coef_old)
        t_k = t_new

        n_iter = iteration + 1

        # Convergence check
        if stopping == 'kkt':
            smooth_grad = XtX @ coef
            smooth_grad -= Xty
            smooth_grad *= inv_n_samples
            smooth_grad += l2_weight * coef

            # Active coordinates must satisfy grad + l1 * sign(coef) == 0;
            # zero coordinates only need |grad| <= l1, so only the excess
            # over the threshold counts as a violation there.
            kkt_violation = cp.where(
                coef != 0,
                cp.abs(smooth_grad + l1_weight * cp.sign(coef)),
                cp.maximum(cp.abs(smooth_grad) - l1_weight, 0),
            )
            violation = float(cp.max(kkt_violation))
        else:
            violation = float(cp.max(cp.abs(coef - coef_old)))

        if violation < tol:
            break

    return coef, n_iter
169
+
170
+
171
+ # ============================================================================
172
+ # Torch Compiled Kernels for Elastic Net
173
+ # ============================================================================
174
+
175
+ def _get_torch_compiled_proximal():
176
+ """Lazy load torch.compile proximal operator."""
177
+ try:
178
+ import torch
179
+ except ImportError:
180
+ return None
181
+
182
+ def _elastic_net_proximal_torch(w_tilde, thresh, l2_scale):
183
+ """Soft thresholding with L2 scaling for Elastic Net."""
184
+ return torch.sign(w_tilde) * torch.maximum(
185
+ torch.abs(w_tilde) - thresh,
186
+ torch.tensor(0.0, device=w_tilde.device, dtype=w_tilde.dtype)
187
+ ) / l2_scale
188
+
189
+ # Compile the proximal operator
190
+ try:
191
+ torch._dynamo.config.suppress_errors = True
192
+ torch._dynamo.config.guard_immutable_object = False
193
+ _elastic_net_proximal_compiled = torch.compile(
194
+ _elastic_net_proximal_torch, mode='reduce-overhead'
195
+ )
196
+ except (AttributeError, RuntimeError):
197
+ _elastic_net_proximal_compiled = _elastic_net_proximal_torch
198
+
199
+ return _elastic_net_proximal_compiled
200
+
201
+
202
def _fit_elasticnet_torch_optimized(X, y, alpha, l1_ratio, n_samples, n_features,
                                    max_iter=1000, tol=1e-4, lipschitz_L=None,
                                    stopping='coef_delta', warmup=True):
    """
    Fit Elastic Net with FISTA on torch tensors using a compiled proximal step.

    ``X`` and ``y`` are used exactly as given: no centering, intercept or
    sample weighting is applied inside this function.

    Parameters
    ----------
    X : torch.Tensor of shape (n_samples, n_features)
        Design matrix.
    y : torch.Tensor of shape (n_samples,)
        Target vector.
    alpha : float
        Overall regularization strength.
    l1_ratio : float
        L1 share of the penalty; the L2 share is ``1 - l1_ratio``.
    n_samples, n_features : int
        Dimensions used for scaling the gradient and sizing the buffers.
    max_iter : int, default=1000
        Maximum number of FISTA iterations.
    tol : float, default=1e-4
        Stop as soon as the monitored violation drops below this value.
    lipschitz_L : float, optional
        Lipschitz constant of the smooth part. When omitted it is computed
        as ``lambda_max(X.T @ X) / n_samples``.
    stopping : {'coef_delta', 'kkt'}, default='coef_delta'
        ``'kkt'`` monitors the largest KKT violation; any other value
        monitors the largest absolute coefficient change.
    warmup : bool, default=True
        Call the proximal operator once before the loop to trigger
        compilation.

    Returns
    -------
    coef : torch.Tensor of shape (n_features,)
        Estimated coefficients (same dtype and device as ``X``).
    n_iter : int
        Number of iterations performed (0 in the degenerate ``L <= 0`` case).

    Raises
    ------
    ImportError
        If torch is not available.
    """
    import math

    import torch

    # Get compiled proximal operator
    proximal = _get_torch_compiled_proximal()
    if proximal is None:
        raise ImportError("Torch not available")

    # Precompute Gram matrix and cross product
    XtX = X.T @ X
    Xty = X.T @ y

    l2_ratio = 1.0 - l1_ratio
    l1_weight = alpha * l1_ratio
    l2_weight = alpha * l2_ratio

    # Lipschitz constant: L = lambda_max(XtX) / n
    if lipschitz_L is not None:
        L = float(lipschitz_L)
    else:
        # eigvalsh returns eigenvalues in ascending order
        L = float(torch.linalg.eigvalsh(XtX)[-1]) / n_samples

    if L <= 0:
        return torch.zeros(n_features, device=X.device, dtype=X.dtype), 0

    step = 1.0 / L
    thresh = l1_weight * step
    l2_scale = 1.0 + l2_weight * step
    inv_n_samples = 1.0 / n_samples

    # Reusable buffers
    buffer_kwargs = {"dtype": X.dtype, "device": X.device}
    coef = torch.zeros(n_features, **buffer_kwargs)
    y_k = torch.zeros(n_features, **buffer_kwargs)
    coef_old = torch.zeros(n_features, **buffer_kwargs)
    grad = torch.empty(n_features, **buffer_kwargs)
    # Zero-initialised so the warm-up call never reads uninitialised memory.
    w_tilde = torch.zeros(n_features, **buffer_kwargs)

    # FISTA state (kept as host-side Python floats)
    t_k = 1.0
    n_iter = 0

    # Warm-up: the result is discarded, only compilation matters.
    if warmup:
        _ = proximal(w_tilde, thresh, l2_scale)

    for iteration in range(max_iter):
        # Store old coefficients for convergence check
        coef_old.copy_(coef)

        # Gradient of the RSS term only: (XtX @ y_k - Xty) / n
        torch.matmul(XtX, y_k, out=grad)
        grad -= Xty
        grad *= inv_n_samples

        # Proximal step: w_tilde = y_k - step * grad, then shrink and scale
        torch.subtract(y_k, grad, alpha=step, out=w_tilde)
        coef = proximal(w_tilde, thresh, l2_scale)

        # FISTA momentum update. t_k is a Python float, and torch.sqrt only
        # accepts tensors, so the scalar recurrence uses math.sqrt.
        t_new = (1.0 + math.sqrt(1.0 + 4.0 * t_k * t_k)) * 0.5
        beta = (t_k - 1.0) / t_new
        y_k = coef + beta * (coef - coef_old)
        t_k = t_new

        n_iter = iteration + 1

        # Convergence check
        if stopping == 'kkt':
            # grad is free to be reused here; it is recomputed next iteration.
            smooth_grad = torch.matmul(XtX, coef, out=grad)
            smooth_grad -= Xty
            smooth_grad *= inv_n_samples
            smooth_grad += l2_weight * coef

            # Active coordinates must satisfy grad + l1 * sign(coef) == 0;
            # zero coordinates only need |grad| <= l1, so only the excess
            # over the threshold counts as a violation there.
            kkt_violation = torch.where(
                coef != 0,
                torch.abs(smooth_grad + l1_weight * torch.sign(coef)),
                torch.clamp(torch.abs(smooth_grad) - l1_weight, min=0),
            )
            violation = float(torch.max(kkt_violation).item())
        else:
            violation = float(torch.max(torch.abs(coef - coef_old)).item())

        if violation < tol:
            break

    return coef, n_iter
305
+
306
+
307
+ # ============================================================================
308
+ # Elastic Net Estimator Class
309
+ # ============================================================================
310
+
311
+ class ElasticNet(BaseEstimator):
312
+ """
313
+ Elastic Net regression with GPU acceleration.
314
+
315
+ Elastic Net combines L1 (Lasso) and L2 (Ridge) regularization, controlled by
316
+ the `l1_ratio` parameter. This provides:
317
+ - Feature selection from L1 (sparse solutions)
318
+ - Grouping effect from L2 (handles correlated features)
319
+
320
+ Parameters
321
+ ----------
322
+ alpha : float, default=1.0
323
+ Regularization strength. Larger values specify stronger regularization.
324
+ Must be non-negative.
325
+ l1_ratio : float, default=0.5
326
+ Elastic Net mixing parameter, between 0 and 1 inclusive.
327
+ - l1_ratio = 1: L1 penalty only (Lasso)
328
+ - l1_ratio = 0: L2 penalty only (Ridge)
329
+ - 0 < l1_ratio < 1: Combination of L1 and L2 penalties
330
+ fit_intercept : bool, default=True
331
+ Whether to calculate the intercept.
332
+ max_iter : int, default=1000
333
+ Maximum number of iterations for the solver.
334
+ tol : float, default=1e-4
335
+ Tolerance for convergence.
336
+ stopping : str, default='coef_delta'
337
+ Stopping criterion: 'coef_delta' or 'kkt'.
338
+ device : str or Device, default='auto'
339
+ Computation device: 'cpu', 'cuda', or 'auto'.
340
+ solver : str, default='fista'
341
+ GPU optimization algorithm: 'fista' or 'admm'.
342
+ Note: ADMM not yet implemented for Elastic Net.
343
+ cpu_solver : str, default='fista'
344
+ CPU optimization algorithm: 'fista' or 'coordinate_descent'.
345
+ Note: coordinate_descent not yet implemented for Elastic Net.
346
+ lipschitz_L : float, optional
347
+ Pre-computed Lipschitz constant. If not provided, will be estimated.
348
+ gpu_memory_cleanup : bool, default=False
349
+ If True, free GPU memory pool after fitting.
350
+
351
+ Attributes
352
+ ----------
353
+ coef_ : ndarray of shape (n_features,)
354
+ Estimated coefficients.
355
+ intercept_ : float
356
+ Independent term.
357
+ n_iter_ : int
358
+ Number of iterations run.
359
+
360
+ See Also
361
+ --------
362
+ Lasso : Lasso regression with L1 regularization.
363
+ Ridge : Ridge regression with L2 regularization.
364
+
365
+ Notes
366
+ -----
367
+ The objective function is:
368
+
369
+ (1 / (2 * n_samples)) * ||y - Xw||²₂ + α * l1_ratio * ||w||₁ + 0.5 * α * (1 - l1_ratio) * ||w||²₂
370
+
371
+ References
372
+ ----------
373
+ .. [1] Zou, H., & Hastie, T. (2005). Regularization and variable selection
374
+ via the elastic net. Journal of the Royal Statistical Society:
375
+ Series B, 67(2), 301-320.
376
+ .. [2] Beck, A., & Teboulle, M. (2009). A fast iterative shrinkage-thresholding
377
+ algorithm for linear inverse problems. SIAM Journal on Imaging Sciences,
378
+ 2(1), 183-202.
379
+
380
+ Examples
381
+ --------
382
+ >>> import numpy as np
383
+ >>> from statgpu.linear_model import ElasticNet
384
+ >>> X = np.random.randn(100, 10)
385
+ >>> y = X @ np.random.randn(10) + np.random.randn(100)
386
+ >>> model = ElasticNet(alpha=1.0, l1_ratio=0.5)
387
+ >>> model.fit(X, y)
388
+ >>> print(model.coef_)
389
+ """
390
+
391
def __init__(
    self,
    alpha: float = 1.0,
    l1_ratio: float = 0.5,
    fit_intercept: bool = True,
    max_iter: int = 1000,
    tol: float = 1e-4,
    stopping: str = "coef_delta",
    device: Union[str, Device] = Device.AUTO,
    n_jobs: Optional[int] = None,
    solver: str = "fista",
    cpu_solver: str = "fista",
    lipschitz_L: Optional[float] = None,
    gpu_memory_cleanup: bool = False,
):
    """Store the hyper-parameters and reset all fitted state.

    ``device`` and ``n_jobs`` are handed to the base class; the string
    options ``stopping``, ``solver`` and ``cpu_solver`` are lower-cased,
    and ``gpu_memory_cleanup`` is coerced to ``bool``.
    """
    super().__init__(device=device, n_jobs=n_jobs)

    # --- penalty and optimisation settings ---
    self.alpha = alpha
    self.l1_ratio = l1_ratio
    self.fit_intercept = fit_intercept
    self.max_iter = max_iter
    self.tol = tol

    # --- string options are normalised to lower case ---
    self.stopping = stopping.lower()
    self.solver = solver.lower()
    self.cpu_solver = cpu_solver.lower()

    self.lipschitz_L = lipschitz_L
    self.gpu_memory_cleanup = bool(gpu_memory_cleanup)

    # --- public fitted attributes (populated by fit) ---
    self.coef_ = None
    self.intercept_ = None
    self.n_iter_ = 0

    # --- internal storage ---
    self._params = None
    self._scale = None
    self._df_resid = None
    self._nobs = None
    self._X_design = None
    self._resid = None
429
+
430
def fit(self, X, y, sample_weight=None, initial_coef=None):
    """
    Fit the Elastic Net model, dispatching to the backend-specific solver.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Training data.
    y : array-like of shape (n_samples,)
        Target values.
    sample_weight : array-like of shape (n_samples,), optional
        Sample weights; forwarded to whichever solver is selected.
    initial_coef : array-like of shape (n_features,), optional
        Warm-start coefficients. Only the CPU solver receives this
        argument; the torch and CUDA paths are called without it.

    Returns
    -------
    self : ElasticNet
        Fitted estimator.
    """
    compute_device = self._get_compute_device()
    backend_name = self._get_backend(backend="auto").name

    # Convert both inputs to the array type of the selected backend.
    X_arr = self._to_array(X, backend=backend_name)
    y_arr = self._to_array(y, backend=backend_name)

    # The torch backend takes precedence over the device check.
    if backend_name == "torch":
        self._fit_torch(X_arr, y_arr, sample_weight)
    elif compute_device == Device.CUDA:
        self._fit_gpu(X_arr, y_arr, sample_weight)
    else:
        self._fit_cpu(X_arr, y_arr, sample_weight, initial_coef=initial_coef)

    self._fitted = True
    return self
469
+
470
def predict(self, X):
    """
    Compute predictions ``X @ coef_`` (plus the intercept when fitted).

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Test data.

    Returns
    -------
    y_pred : array of shape (n_samples,)
        Predicted values. The array type follows the compute device:
        a CuPy array for CUDA, a torch tensor for TORCH, otherwise the
        result of a NumPy matrix product.

    Raises
    ------
    RuntimeError
        If ``coef_`` is still ``None`` (model not fitted).
    """
    if self.coef_ is None:
        raise RuntimeError("Model has not been fitted yet.")

    target = self._get_compute_device()

    if target == Device.CUDA:
        import cupy as cp

        features = cp.asarray(self._to_array(X, Device.CUDA))
        weights = cp.asarray(self.coef_)
        result = features @ weights
        if self.fit_intercept:
            result += cp.asarray(self.intercept_, dtype=weights.dtype)
        return result

    if target == Device.TORCH:
        import torch

        # Inputs are promoted to float64 on this path.
        features = self._to_array(X, Device.TORCH, backend="torch").to(torch.float64)
        weights = torch.as_tensor(
            self.coef_, dtype=features.dtype, device=features.device
        )
        result = features @ weights
        if self.fit_intercept:
            offset = torch.as_tensor(
                self.intercept_, dtype=result.dtype, device=result.device
            )
            result = result + offset
        return result

    # CPU / NumPy path
    X = np.asarray(X)
    result = X @ self.coef_
    if self.fit_intercept:
        result += self.intercept_
    return result
512
+
513
def score(self, X, y):
    """
    Return the coefficient of determination R² of the prediction.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Test data.
    y : array-like of shape (n_samples,)
        True values.

    Returns
    -------
    r2 : float
        ``1 - SS_res / SS_tot``; ``0.0`` when the total sum of squares
        is not positive.
    """
    y_pred = self.predict(X)
    target = self._get_compute_device()

    if target == Device.CUDA:
        import cupy as cp

        truth = cp.asarray(self._to_array(y, Device.CUDA))
        ss_res = cp.sum((truth - y_pred) ** 2)
        ss_tot = cp.sum((truth - cp.mean(truth)) ** 2)
        if float(ss_tot.item()) > 0:
            return float((1 - ss_res / ss_tot).item())
        return 0.0

    if target == Device.TORCH:
        import torch

        # Match the dtype of the predictions before subtracting.
        truth = self._to_array(y, Device.TORCH, backend="torch").to(y_pred.dtype)
        ss_res = torch.sum((truth - y_pred) ** 2)
        ss_tot = torch.sum((truth - torch.mean(truth)) ** 2)
        if float(ss_tot.item()) > 0:
            return float((1 - ss_res / ss_tot).item())
        return 0.0

    # CPU / NumPy path
    y_pred = np.asarray(y_pred)
    y = self._to_numpy(y)
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    if ss_tot > 0:
        return 1 - ss_res / ss_tot
    return 0.0
550
+
551
+ def _soft_threshold(self, x, gamma):
552
+ """Standard soft thresholding operator for Lasso."""
553
+ return np.sign(x) * np.maximum(np.abs(x) - gamma, 0)
554
+
555
+ def _soft_threshold_elastic(self, x, gamma, l2_scale):
556
+ """
557
+ Elastic Net soft thresholding operator.
558
+
559
+ Applies soft thresholding then divides by L2 scaling factor.
560
+ This is the proximal operator for L1 + L2 regularization.
561
+
562
+ Parameters
563
+ ----------
564
+ x : ndarray
565
+ Input array
566
+ gamma : float
567
+ Threshold parameter (alpha * l1_ratio * step)
568
+ l2_scale : float
569
+ L2 scaling factor (1 + alpha * (1 - l1_ratio) * step)
570
+
571
+ Returns
572
+ -------
573
+ ndarray
574
+ Soft thresholded and scaled result
575
+ """
576
+ return self._soft_threshold(x, gamma) / l2_scale
577
+
578
+ def _fit_cpu(self, X, y, sample_weight=None, initial_coef=None):
579
+ """
580
+ Fit using CPU FISTA solver with optimized implementation.
581
+
582
+ Elastic Net proximal gradient update:
583
+ grad = (XtX @ w - Xty) / n # gradient of RSS only
584
+ w = soft_threshold(w - step*grad, alpha*l1_ratio*step) / (1 + alpha*(1-l1_ratio)*step)
585
+
586
+ Note: L2 regularization is handled in the proximal step, NOT in the gradient.
587
+
588
+ Parameters
589
+ ----------
590
+ X : ndarray
591
+ Training data (n_samples, n_features).
592
+ y : ndarray
593
+ Target values (n_samples,).
594
+ sample_weight : ndarray, optional
595
+ Sample weights.
596
+ initial_coef : ndarray, optional
597
+ Initial coefficient vector for warm-start. If provided, avoids starting from zero
598
+ and can significantly speed up convergence along a regularization path.
599
+ """
600
+ X = np.asarray(X)
601
+ y = np.asarray(y)
602
+
603
+ n_samples, n_features = X.shape
604
+ self._nobs = n_samples
605
+
606
+ if sample_weight is not None:
607
+ sample_weight = np.asarray(sample_weight)
608
+ sqrt_sw = np.sqrt(sample_weight)
609
+ X = X * sqrt_sw[:, np.newaxis]
610
+ y = y * sqrt_sw
611
+
612
+ if self.fit_intercept:
613
+ X_mean = np.mean(X, axis=0)
614
+ y_mean = np.mean(y)
615
+ # Memory-efficient centering: avoid creating full X_centered (n×p) matrix
616
+ XtX = X.T @ X - n_samples * np.outer(X_mean, X_mean)
617
+ Xty = X.T @ y - n_samples * X_mean * y_mean
618
+ else:
619
+ y_mean = 0.0
620
+ XtX = X.T @ X
621
+ Xty = X.T @ y
622
+
623
+ if Xty.ndim == 0:
624
+ Xty = Xty.reshape(1)
625
+ if Xty.ndim == 1:
626
+ Xty = Xty.reshape(-1, 1)
627
+ Xty_flat = Xty.flatten()
628
+
629
+ # Elastic Net parameters
630
+ alpha = float(self.alpha)
631
+ l1_ratio = float(self.l1_ratio)
632
+ l2_ratio = 1.0 - l1_ratio
633
+
634
+ # Lipschitz constant: L = lambda_max(XtX)/n (RSS only, L2 is handled in proximal step)
635
+ if self.lipschitz_L is not None:
636
+ L = float(self.lipschitz_L)
637
+ else:
638
+ try:
639
+ eig_max = np.linalg.eigvalsh(XtX)[-1]
640
+ L = float(eig_max / n_samples)
641
+ except Exception:
642
+ # Frobenius norm squared / n = trace(XtX) / n = sum(X_centered^2) / n
643
+ L = float(np.trace(XtX) / n_samples)
644
+
645
+ if L <= 0:
646
+ # Degenerate case: apply pure proximal operator
647
+ thresh = alpha * l1_ratio
648
+ l2_scale = 1.0 + alpha * l2_ratio
649
+ coef = self._soft_threshold_elastic(np.zeros(n_features), thresh, l2_scale)
650
+ self.n_iter_ = 0
651
+ else:
652
+ step = 1.0 / L
653
+
654
+ # Elastic Net proximal parameters
655
+ thresh = alpha * l1_ratio * step
656
+ l2_scale = 1.0 + alpha * l2_ratio * step
657
+ inv_l2_scale = 1.0 / l2_scale
658
+ inv_n_samples = 1.0 / n_samples
659
+
660
+ # FISTA variables - use warm-start if available
661
+ if initial_coef is not None and len(initial_coef) == n_features:
662
+ coef = np.asarray(initial_coef, dtype=np.float64).copy()
663
+ else:
664
+ coef = np.zeros(n_features)
665
+ y_k = coef.copy()
666
+ t_k = 1.0
667
+
668
+ # Pre-allocate buffers to reduce allocation overhead
669
+ coef_old = np.empty_like(coef)
670
+ grad = np.empty_like(coef)
671
+ w_tilde = np.empty_like(coef)
672
+ delta = np.empty_like(coef)
673
+
674
+ for iteration in range(self.max_iter):
675
+ # Store old coefficients (in-place copy)
676
+ coef_old[:] = coef
677
+
678
+ # Gradient of RSS ONLY (L2 is handled in proximal step)
679
+ # grad = (XtX @ y_k - Xty) / n_samples
680
+ np.matmul(XtX, y_k, out=grad)
681
+ grad -= Xty_flat
682
+ grad *= inv_n_samples
683
+
684
+ # Proximal gradient step with Elastic Net soft thresholding
685
+ # w_tilde = y_k - step * grad
686
+ np.subtract(y_k, step * grad, out=w_tilde)
687
+
688
+ # coef = soft_threshold(w_tilde, thresh) / l2_scale
689
+ # Using vectorized operations with pre-computed inv_l2_scale
690
+ np.abs(w_tilde, out=delta)
691
+ np.maximum(delta - thresh, 0, out=delta)
692
+ coef[:] = np.sign(w_tilde) * delta * inv_l2_scale
693
+
694
+ # Momentum update (FISTA)
695
+ sqrt_arg = 1.0 + 4.0 * t_k * t_k
696
+ t_new = (1.0 + np.sqrt(sqrt_arg)) * 0.5
697
+ beta = (t_k - 1.0) / t_new
698
+ # y_k = coef + beta * (coef - coef_old)
699
+ np.subtract(coef, coef_old, out=y_k)
700
+ y_k *= beta
701
+ y_k += coef
702
+ t_k = t_new
703
+
704
+ # Convergence test - use L-infinity norm of coefficient change
705
+ np.abs(coef - coef_old, out=delta)
706
+ violation = float(np.max(delta))
707
+
708
+ if violation < self.tol:
709
+ self.n_iter_ = iteration + 1
710
+ break
711
+ else:
712
+ self.n_iter_ = self.max_iter
713
+
714
+ # Compute intercept
715
+ if self.fit_intercept:
716
+ self.intercept_ = float(y_mean - X_mean @ coef)
717
+ self.coef_ = coef
718
+ self._params = np.concatenate([[self.intercept_], self.coef_])
719
+ else:
720
+ self.intercept_ = 0.0
721
+ self.coef_ = coef
722
+ self._params = coef.copy()
723
+
724
+ self._df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))
725
+
726
def _soft_threshold_cupy(self, x, gamma, l2_scale=None):
    """Apply the soft-thresholding operator to a CuPy array.

    Computes ``sign(x) * max(|x| - gamma, 0)`` element-wise. When
    ``l2_scale`` is given, the thresholded values are additionally
    divided by it.
    """
    import cupy as cp

    shrunk = cp.sign(x) * cp.maximum(cp.abs(x) - gamma, 0)
    if l2_scale is None:
        return shrunk
    return shrunk / l2_scale
732
+
733
+ def _cleanup_cuda_memory(self):
734
+ """Free CuPy memory pool."""
735
+ if not self.gpu_memory_cleanup:
736
+ return
737
+ try:
738
+ import cupy as cp
739
+ cp.get_default_memory_pool().free_all_blocks()
740
+ cp.get_default_pinned_memory_pool().free_all_blocks()
741
+ except Exception:
742
+ pass
743
+
744
def _fit_gpu(self, X, y, sample_weight=None):
    """
    Fit using GPU (CuPy) with optimized FISTA solver and fused kernels.

    Parameters
    ----------
    X : array-like with a 2-D ``shape`` (n_samples, n_features)
        Design matrix; converted with ``cp.asarray``.
    y : array-like
        Target values; converted with ``cp.asarray`` and flattened to 1-D.
    sample_weight : array-like, optional
        Per-sample weights. When given, the rows of ``X`` and the entries
        of ``y`` are multiplied by ``sqrt(sample_weight)``.

    Raises
    ------
    ValueError
        If ``self.solver`` is not ``'fista'``.

    Notes
    -----
    Sets ``_nobs``, ``n_iter_``, ``intercept_``, ``coef_``, ``_params`` and
    ``_df_resid`` on the estimator, then calls ``_cleanup_cuda_memory``.
    """
    import cupy as cp

    if self.solver not in ("fista",):
        raise ValueError("Elastic Net currently only supports 'fista' solver")

    n_samples, n_features = X.shape
    self._nobs = n_samples

    # Ensure CuPy arrays. y is flattened *before* any weighting: a column
    # vector of shape (n, 1) multiplied by weights of shape (n,) would
    # otherwise broadcast to an (n, n) matrix.
    X = cp.asarray(X)
    y = cp.asarray(y).reshape(-1)

    if sample_weight is not None:
        # Flatten the weights as well so (n, 1) weights behave like (n,).
        sqrt_sw = cp.sqrt(cp.asarray(sample_weight).reshape(-1))
        X = X * sqrt_sw[:, cp.newaxis]
        y = y * sqrt_sw

    # Center for intercept.
    # NOTE(review): when sample_weight is given, the means below are plain
    # means of the sqrt-weight-scaled data - confirm this is the intended
    # weighted-intercept handling and matches the CPU path.
    if self.fit_intercept:
        X_mean = cp.mean(X, axis=0)
        y_mean = cp.mean(y)
        X_centered = X - X_mean
        y_centered = y - y_mean
    else:
        X_centered = X
        y_centered = y

    # Use optimized implementation with fused kernels
    coef, self.n_iter_ = _fit_elasticnet_cupy_optimized(
        X=X_centered,
        y=y_centered,
        alpha=float(self.alpha),
        l1_ratio=float(self.l1_ratio),
        n_samples=n_samples,
        n_features=n_features,
        max_iter=self.max_iter,
        tol=self.tol,
        lipschitz_L=self.lipschitz_L,
        stopping=self.stopping,
        warmup=True  # Enable warm-up to avoid JIT overhead
    )

    # Build the full parameter vector on the GPU so only one
    # device-to-host transfer is needed.
    if self.fit_intercept:
        intercept_gpu = y_mean - X_mean @ coef
        coef_full = cp.concatenate([intercept_gpu.reshape(1), coef])
    else:
        coef_full = coef

    # Transfer to CPU
    coef_full_np = coef_full.get()

    if self.fit_intercept:
        self.intercept_ = float(coef_full_np[0])
        self.coef_ = coef_full_np[1:]
        self._params = coef_full_np
    else:
        self.intercept_ = 0.0
        self.coef_ = coef_full_np
        self._params = coef_full_np

    self._df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))

    # Cleanup
    self._cleanup_cuda_memory()
818
+
819
def _soft_threshold_elastic_cupy(self, x, gamma, l2_scale):
    """Elastic Net soft thresholding for CuPy arrays.

    Returns ``sign(x) * max(|x| - gamma, 0) / l2_scale`` element-wise.
    """
    import cupy as cp

    magnitude = cp.maximum(cp.abs(x) - gamma, 0)
    return cp.sign(x) * magnitude / l2_scale
823
+
824
+ def _soft_threshold_torch(self, x, gamma, l2_scale=None):
825
+ """Soft thresholding operator for Torch tensors."""
826
+ import torch
827
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
828
+ if l2_scale is not None:
829
+ return torch.sign(x) * torch.maximum(torch.abs(x) - gamma, zero) / l2_scale
830
+ return torch.sign(x) * torch.maximum(torch.abs(x) - gamma, zero)
831
+
832
+ def _soft_threshold_elastic_torch(self, x, gamma, l2_scale):
833
+ """Elastic Net soft thresholding for Torch."""
834
+ import torch
835
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
836
+ return torch.sign(x) * torch.maximum(torch.abs(x) - gamma, zero) / l2_scale
837
+
838
+ def _cleanup_torch_memory(self):
839
+ """Free Torch memory pool."""
840
+ if not self.gpu_memory_cleanup:
841
+ return
842
+ try:
843
+ import torch
844
+ if torch.cuda.is_available():
845
+ torch.cuda.empty_cache()
846
+ except Exception:
847
+ pass
848
+
849
def _fit_torch(self, X, y, sample_weight=None):
    """
    Fit using Torch with the optimized FISTA solver.

    The actual optimisation is delegated to
    ``_fit_elasticnet_torch_optimized``.

    Parameters
    ----------
    X : numpy.ndarray or torch.Tensor with a 2-D ``shape``
        Design matrix. A NumPy array is moved to ``'cuda'``; a tensor is
        kept on its current device. Cast to float64.
    y : numpy.ndarray or torch.Tensor
        Target values. Placed on the same device as ``X``, cast to float64
        and flattened to 1-D.
    sample_weight : numpy.ndarray or torch.Tensor, optional
        Per-sample weights. When given, the rows of ``X`` and the entries
        of ``y`` are multiplied by ``sqrt(sample_weight)``.

    Raises
    ------
    ValueError
        If ``self.solver`` is not ``'fista'``.

    Notes
    -----
    Sets ``_nobs``, ``n_iter_``, ``intercept_``, ``coef_``, ``_params`` and
    ``_df_resid`` on the estimator, then calls ``_cleanup_torch_memory``.
    """
    import torch

    if self.solver not in ("fista",):
        raise ValueError("Torch backend currently only supports 'fista' solver")

    n_samples, n_features = X.shape
    self._nobs = n_samples

    # X decides the working device: NumPy input goes to the GPU, an
    # existing tensor stays where the caller put it.
    if not isinstance(X, torch.Tensor):
        X = torch.from_numpy(X).to('cuda')
    if X.dtype != torch.float64:
        X = X.to(torch.float64)
    device = X.device

    # y follows X's device (instead of a hard-coded 'cuda') so mixed
    # inputs cannot end up on different devices. It is flattened *before*
    # any weighting: an (n, 1) column times (n,) weights would otherwise
    # broadcast to an (n, n) matrix.
    if not isinstance(y, torch.Tensor):
        y = torch.from_numpy(y)
    y = y.to(device=device, dtype=torch.float64).reshape(-1)

    if sample_weight is not None:
        if not isinstance(sample_weight, torch.Tensor):
            sample_weight = torch.from_numpy(sample_weight)
        # Same device/dtype as X; flattened so (n, 1) weights act like (n,).
        sample_weight = sample_weight.to(device=device, dtype=torch.float64).reshape(-1)
        sqrt_sw = torch.sqrt(sample_weight)
        X = X * sqrt_sw[:, None]
        y = y * sqrt_sw

    # Center for intercept.
    # NOTE(review): when sample_weight is given, the means below are plain
    # means of the sqrt-weight-scaled data - confirm this is the intended
    # weighted-intercept handling and matches the CPU path.
    if self.fit_intercept:
        X_mean = torch.mean(X, dim=0)
        y_mean = torch.mean(y)
        X_centered = X - X_mean
        y_centered = y - y_mean
    else:
        X_centered = X
        y_centered = y

    # Use optimized implementation with torch.compile()
    coef, self.n_iter_ = _fit_elasticnet_torch_optimized(
        X=X_centered,
        y=y_centered,
        alpha=float(self.alpha),
        l1_ratio=float(self.l1_ratio),
        n_samples=n_samples,
        n_features=n_features,
        max_iter=self.max_iter,
        tol=self.tol,
        lipschitz_L=self.lipschitz_L,
        stopping=self.stopping,
        warmup=True  # Enable warm-up to avoid JIT overhead
    )

    # Build the full parameter vector on the device so only one
    # device-to-host transfer is needed.
    if self.fit_intercept:
        intercept_torch = y_mean - X_mean @ coef
        coef_full = torch.cat([intercept_torch.reshape(1), coef])
    else:
        coef_full = coef

    # Transfer to CPU
    coef_full_np = coef_full.cpu().numpy()

    if self.fit_intercept:
        self.intercept_ = float(coef_full_np[0])
        self.coef_ = coef_full_np[1:]
        self._params = coef_full_np
    else:
        self.intercept_ = 0.0
        self.coef_ = coef_full_np
        self._params = coef_full_np

    self._df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))

    # Cleanup
    self._cleanup_torch_memory()
930
+
931
+
932
+ # =============================================================================
933
+ # V9 thin wrapper
934
+ # =============================================================================
935
+
936
+ from statgpu.linear_model.penalized._penalized_linear import PenalizedLinearRegression as _PenalizedLinearRegression