statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,4876 @@
1
+ """
2
+ Lasso regression with full statistical inference and GPU support.
3
+ """
4
+
5
+ from collections import OrderedDict
6
+ import hashlib
7
+ from typing import Any, Dict, Optional, Tuple, Union
8
+ import os
9
+ import warnings
10
+ import numpy as np
11
+ from scipy import stats
12
+ from scipy.stats import norm as _norm_dist
13
+
14
+ try:
15
+ from numba import njit
16
+
17
+ _NUMBA_AVAILABLE = True
18
+ except Exception:
19
+ njit = None
20
+ _NUMBA_AVAILABLE = False
21
+
22
+ from statgpu._base import BaseEstimator
23
+ from statgpu._config import Device
24
+ from statgpu.linear_model._cv_base import CVEstimatorBase
25
+ from statgpu.backends import get_backend
26
+ from statgpu.inference._distributions_backend import (
27
+ norm,
28
+ t,
29
+ )
30
+
31
+
32
+ _NUMBA_CD_DISABLED = str(os.getenv("STATGPU_DISABLE_NUMBA_CD", "0")).strip().lower() in (
33
+ "1",
34
+ "true",
35
+ "yes",
36
+ "on",
37
+ )
38
+
39
+ _LASSO_CV_ALPHA_CACHE_MAXSIZE = int(os.getenv("STATGPU_LASSO_CV_CACHE_SIZE", "64"))
40
+ _LASSO_CV_ALPHA_CACHE: "OrderedDict[Tuple[Any, ...], Dict[str, Any]]" = OrderedDict()
41
+ _LASSO_DEBIASED_M_CACHE_MAXSIZE = int(os.getenv("STATGPU_LASSO_DEBIASED_M_CACHE_SIZE", "16"))
42
+ _LASSO_DEBIASED_M_CACHE: "OrderedDict[Tuple[Any, ...], np.ndarray]" = OrderedDict()
43
+ _LASSO_DEBIASED_M_GPU_HASH_ROW_CHUNK = 1024
44
+
45
+
46
+ # ============================================================================
47
+ # CuPy Fused Kernels for Lasso - Now implemented as Lasso class methods
48
+ # See Lasso._get_cupy_fused_kernels() for details.
49
+ # ============================================================================
50
+
51
+
52
def _debiased_m_cache_get(key):
    """Return the cached debiased-M entry for ``key``, or ``None`` on a miss.

    A hit also marks the entry as most recently used by moving it to the
    end of the module-level ``_LASSO_DEBIASED_M_CACHE`` ordered dict, so
    that the eviction in ``_debiased_m_cache_put`` drops older entries first.
    """
    cached = _LASSO_DEBIASED_M_CACHE.get(key)
    if cached is None:
        return None
    _LASSO_DEBIASED_M_CACHE.move_to_end(key)
    return cached
57
+
58
+
59
def _debiased_m_cache_put(key, value):
    """Store ``value`` under ``key`` and evict least-recently-used entries.

    The entry is inserted (or overwritten) in the module-level
    ``_LASSO_DEBIASED_M_CACHE`` and moved to the most-recently-used end.
    Entries are then popped from the oldest end until the cache holds at
    most ``_LASSO_DEBIASED_M_CACHE_MAXSIZE`` items.

    The capacity is clamped at zero: the limit comes from an environment
    variable, and a negative value would otherwise keep the eviction loop
    running on an already-empty cache, where ``popitem`` raises ``KeyError``.
    A capacity of zero (or less) therefore simply disables caching.
    """
    _LASSO_DEBIASED_M_CACHE[key] = value
    _LASSO_DEBIASED_M_CACHE.move_to_end(key)
    capacity = max(int(_LASSO_DEBIASED_M_CACHE_MAXSIZE), 0)
    while len(_LASSO_DEBIASED_M_CACHE) > capacity:
        # last=False pops from the front, i.e. the least recently used entry.
        _LASSO_DEBIASED_M_CACHE.popitem(last=False)
64
+
65
+
66
+ def _debiased_m_key_from_numpy_design(
67
+ X: np.ndarray,
68
+ *,
69
+ n: int,
70
+ p: int,
71
+ lam_nw: float,
72
+ tol: float,
73
+ ):
74
+ X_cache = np.asarray(X)
75
+ if not X_cache.flags["C_CONTIGUOUS"]:
76
+ X_cache = np.ascontiguousarray(X_cache)
77
+ h = hashlib.blake2b(digest_size=32)
78
+ h.update(np.asarray([int(n), int(p)], dtype=np.int64).tobytes())
79
+ h.update(str(X_cache.dtype).encode("utf-8"))
80
+ h.update(np.asarray([float(lam_nw), float(tol)], dtype=np.float64).tobytes())
81
+ h.update(X_cache.view(np.uint8).tobytes())
82
+ return h.hexdigest()
83
+
84
+
85
+ def _debiased_m_key_from_sample(
86
+ *,
87
+ n: int,
88
+ p: int,
89
+ dtype_name: str,
90
+ sample_block: np.ndarray,
91
+ lam_nw: float,
92
+ tol: float,
93
+ ):
94
+ """Generate cache key for debiased M matrix from a sample block of X.
95
+
96
+ This is used for Torch backend where we don't want to hash the entire matrix.
97
+ """
98
+ h = hashlib.blake2b(digest_size=32)
99
+ h.update(np.asarray([int(n), int(p)], dtype=np.int64).tobytes())
100
+ h.update(dtype_name.encode("utf-8"))
101
+ h.update(np.asarray([float(lam_nw), float(tol)], dtype=np.float64).tobytes())
102
+ if not sample_block.flags["C_CONTIGUOUS"]:
103
+ sample_block = np.ascontiguousarray(sample_block)
104
+ h.update(sample_block.view(np.uint8).tobytes())
105
+ return h.hexdigest()
106
+
107
+
108
+ class Lasso(BaseEstimator):
109
+ """
110
+ Lasso regression (L1 regularization) with GPU acceleration
111
+ and full statistical inference.
112
+
113
+ CPU solver supports multiple algorithms (coordinate descent by default, and FISTA when cpu_solver='fista').
114
+ GPU solver supports multiple algorithms via `solver` (e.g. FISTA / ADMM).
115
+
116
+ Parameters
117
+ ----------
118
+ alpha : float, default=1.0
119
+ Regularization strength. Larger values specify stronger regularization.
120
+ Must be non-negative.
121
+ fit_intercept : bool, default=True
122
+ Whether to calculate the intercept.
123
+ max_iter : int, default=1000
124
+ Maximum number of iterations for coordinate descent.
125
+ tol : float, default=1e-4
126
+ Tolerance for convergence.
127
+ device : str or Device, default='auto'
128
+ Computation device: 'cpu', 'cuda', or 'auto'.
129
+ cpu_solver : str, default='coordinate_descent'
130
+ CPU optimization algorithm: 'coordinate_descent' or 'fista'.
131
+ GPU uses the `solver` parameter instead.
132
+
133
+ Attributes
134
+ ----------
135
+ coef_ : ndarray of shape (n_features,)
136
+ Estimated coefficients.
137
+ intercept_ : float
138
+ Independent term.
139
+ n_iter_ : int
140
+ Number of iterations run.
141
+ """
142
+
143
+ # Internal cache for CuPy fused kernels (populated on first GPU use)
144
+ _cupy_fused_kernels = None
145
+
146
def __init__(
    self,
    alpha: float = 1.0,
    fit_intercept: bool = True,
    max_iter: int = 1000,
    tol: float = 1e-4,
    stopping: str = "coef_delta",
    inference_method: str = "cpu_ols_inference",
    n_bootstrap: int = 200,
    bootstrap_random_state: Optional[int] = None,
    enable_simultaneous_inference: bool = False,
    simultaneous_method: str = "maxz_bootstrap",
    simultaneous_alpha: float = 0.05,
    simultaneous_n_bootstrap: int = 1000,
    simultaneous_random_state: Optional[int] = None,
    simultaneous_include_intercept: bool = False,
    device: Union[str, Device] = Device.AUTO,
    n_jobs: Optional[int] = None,
    compute_inference: bool = True,
    solver: str = "fista",
    cpu_solver: str = "coordinate_descent",
    lipschitz_L: Optional[float] = None,
    admm_rho: float = 1.0,
    gpu_memory_cleanup: bool = False,
):
    """Store hyper-parameters and initialise fitted/inference state.

    String options (``stopping``, ``inference_method``, ``simultaneous_method``,
    ``solver``, ``cpu_solver``) are lower-cased; counts are coerced with
    ``int``, ``simultaneous_alpha`` with ``float`` and flags with ``bool``.
    The legacy inference names ``"naive_ols"`` and ``"gpu_naive_ols"`` are
    mapped to ``"cpu_ols_inference"`` and ``"gpu_ols_inference"``.
    No validation of the option values happens here.
    """
    super().__init__(device=device, n_jobs=n_jobs)

    # --- optimisation settings -------------------------------------------
    self.alpha = alpha
    self.fit_intercept = fit_intercept
    self.max_iter = max_iter
    self.tol = tol
    self.stopping = stopping.lower()

    # --- inference settings ----------------------------------------------
    # Backwards-compatible aliases from before the semantic rename:
    #   "naive_ols"     -> CPU-side t-distribution inference
    #   "gpu_naive_ols" -> GPU-side t-distribution inference with minimal
    #                      residual/design transfers
    legacy_names = {
        "naive_ols": "cpu_ols_inference",
        "gpu_naive_ols": "gpu_ols_inference",
    }
    requested_method = inference_method.lower()
    self.inference_method = legacy_names.get(requested_method, requested_method)
    self.n_bootstrap = int(n_bootstrap)
    self.bootstrap_random_state = bootstrap_random_state
    self.enable_simultaneous_inference = bool(enable_simultaneous_inference)
    self.simultaneous_method = str(simultaneous_method).lower()
    self.simultaneous_alpha = float(simultaneous_alpha)
    self.simultaneous_n_bootstrap = int(simultaneous_n_bootstrap)
    self.simultaneous_random_state = simultaneous_random_state
    self.simultaneous_include_intercept = bool(simultaneous_include_intercept)
    self.compute_inference = compute_inference

    # --- solver selection and GPU housekeeping ----------------------------
    self.solver = solver.lower()
    self.cpu_solver = cpu_solver.lower()
    self.lipschitz_L = lipschitz_L
    self.admm_rho = admm_rho
    self.gpu_memory_cleanup = bool(gpu_memory_cleanup)

    # --- fitted attributes -------------------------------------------------
    self.coef_ = None
    self.intercept_ = None
    self.n_iter_ = 0

    # --- internal storage for inference (filled in by fit) ----------------
    for attr in (
        "_X_design",
        "_y",
        "_resid",
        "_scale",
        "_nobs",
        "_df_resid",
        "_params",
        "_bse",
        "_tvalues",
        "_pvalues",
        "_conf_int",
        "_conf_int_simultaneous",
    ):
        setattr(self, attr, None)
    self._simultaneous_enabled = False
    for attr in (
        "_simultaneous_method",
        "_simultaneous_alpha",
        "_simultaneous_n_bootstrap",
        "_simultaneous_critical_value",
        "_simultaneous_target_mask",
        "_debiased_M_cpu",
    ):
        setattr(self, attr, None)
    self._inference_cautions = []
227
def fit(self, X, y, sample_weight=None):
    """Fit the Lasso model and, optionally, its inference quantities.

    Validates the simultaneous-inference configuration, resolves the compute
    device and backend, dispatches to the Torch, CuPy (``Device.CUDA``) or
    CPU fitting routine, then runs post-fit inference when it was not
    already handled by the backend path.  Any inference cautions are
    emitted as ``UserWarning``.

    Note that on a CUDA/Torch device an ``inference_method`` of
    ``"cpu_ols_inference"`` is rewritten on the instance to
    ``"gpu_ols_inference"``.

    Raises
    ------
    ValueError
        If ``inference_method='gpu_ols_inference'`` is requested on CPU.
    NotImplementedError
        If inference is requested on a CUDA/Torch device with a method other
        than ``'gpu_ols_inference'`` or ``'debiased'``.

    Returns
    -------
    self
    """
    self._validate_simultaneous_config()
    self._reset_simultaneous_outputs()
    device = self._get_compute_device()

    # Backend lookup supports explicit torch backend selection.
    backend = self._get_backend(backend="auto")
    backend_name = backend.name

    if device == Device.CPU and self.inference_method == "gpu_ols_inference":
        raise ValueError(
            "inference_method='gpu_ols_inference' requires device='cuda' or "
            "device='torch'. Use inference_method='cpu_ols_inference' on CPU."
        )

    on_accelerator = device in (Device.CUDA, Device.TORCH)
    if on_accelerator and self.inference_method == "cpu_ols_inference":
        self.inference_method = "gpu_ols_inference"

    device_side_methods = ("gpu_ols_inference", "debiased")
    if device == Device.CPU:
        self._y = np.asarray(y)
    elif self.compute_inference and self.inference_method not in device_side_methods:
        # Only CPU-side inference needs a host copy of y; y may already be
        # a CuPy array, so go through the safe conversion helper.
        self._y = self._to_numpy(y)
    else:
        # Accelerator path without CPU-side inference: avoid host copies.
        self._y = None

    if (
        self.compute_inference
        and on_accelerator
        and self.inference_method not in device_side_methods
    ):
        raise NotImplementedError(
            f"Lasso inference_method='{self.inference_method}' is not implemented "
            f"for device='{device.value}' without CPU fallback."
        )

    X_arr = self._to_array(X, backend=backend_name)
    y_arr = self._to_array(y, backend=backend_name)

    # Route to the backend-specific fitting routine.
    if backend_name == "torch":
        self._fit_torch(X_arr, y_arr, sample_weight)
    elif device == Device.CUDA:
        self._fit_gpu(X_arr, y_arr, sample_weight)
    else:
        self._fit_cpu(X_arr, y_arr, sample_weight)

    # Methods listed here are not followed by the generic post-fit pass.
    skip_post_fit = {"gpu_ols_inference"}
    if self.inference_method == "debiased" and (
        device == Device.CUDA or backend_name == "torch"
    ):
        skip_post_fit.add("debiased")

    if self.compute_inference and self.inference_method not in skip_post_fit:
        self._compute_inference()
    if self.enable_simultaneous_inference:
        self._compute_simultaneous_inference()

    self._inference_cautions = self._build_inference_cautions()
    for caution in self._inference_cautions:
        warnings.warn(caution, UserWarning, stacklevel=2)

    self._fitted = True
    return self
292
+
293
+ def _validate_simultaneous_config(self):
294
+ if not self.enable_simultaneous_inference:
295
+ return
296
+ if not self.compute_inference:
297
+ raise ValueError(
298
+ "enable_simultaneous_inference=True requires compute_inference=True."
299
+ )
300
+ if self.inference_method != "debiased":
301
+ raise ValueError(
302
+ "enable_simultaneous_inference=True currently requires "
303
+ "inference_method='debiased'."
304
+ )
305
+ if self.simultaneous_method != "maxz_bootstrap":
306
+ raise ValueError(
307
+ "simultaneous_method must be 'maxz_bootstrap'."
308
+ )
309
+ if not (0.0 < self.simultaneous_alpha < 1.0):
310
+ raise ValueError("simultaneous_alpha must be in (0, 1).")
311
+ if self.simultaneous_n_bootstrap <= 0:
312
+ raise ValueError("simultaneous_n_bootstrap must be a positive integer.")
313
+
314
+ def _reset_simultaneous_outputs(self):
315
+ self._conf_int_simultaneous = None
316
+ self._simultaneous_enabled = False
317
+ self._simultaneous_method = None
318
+ self._simultaneous_alpha = None
319
+ self._simultaneous_n_bootstrap = None
320
+ self._simultaneous_critical_value = None
321
+ self._simultaneous_target_mask = None
322
+ self._debiased_M_cpu = None
323
+
324
+ def _build_inference_cautions(self):
325
+ cautions = []
326
+ if not self.compute_inference:
327
+ return cautions
328
+
329
+ if self.inference_method in ("cpu_ols_inference", "gpu_ols_inference"):
330
+ cautions.append(
331
+ "Lasso OLS-style post-selection intervals are heuristic and do not "
332
+ "provide valid selective-inference confidence coverage."
333
+ )
334
+
335
+ if self.inference_method == "debiased":
336
+ cautions.append(
337
+ "Debiased Lasso currently reports per-coefficient (marginal) confidence "
338
+ "intervals only; joint/multiple-testing coverage is not guaranteed."
339
+ )
340
+ if self._simultaneous_enabled:
341
+ target_txt = (
342
+ "including intercept"
343
+ if (self.fit_intercept and self.simultaneous_include_intercept)
344
+ else "excluding intercept"
345
+ )
346
+ cautions.append(
347
+ "Simultaneous inference enabled via maxz_bootstrap with joint coverage "
348
+ f"target set {target_txt}."
349
+ )
350
+ if self.fit_intercept and self.simultaneous_include_intercept:
351
+ cautions.append(
352
+ "Intercept is included using the same max-|Z| critical value "
353
+ "calibrated on feature coefficients."
354
+ )
355
+
356
+ return cautions
357
+
358
@staticmethod
def _get_cupy_fused_kernels():
    """
    Build (once) and return the CuPy kernels used by the Lasso FISTA solver.

    The kernels are created on the first successful call and stored on the
    class attribute ``Lasso._cupy_fused_kernels``, so later calls from any
    instance return the same dictionary.  Fusing several elementwise
    operations into one kernel reduces the number of GPU kernel launches.

    Returns
    -------
    dict or None
        Mapping with the keys ``'soft_threshold'``, ``'fista_momentum'``,
        ``'kkt_violation'`` (``cupy.fuse`` functions) and
        ``'elementwise_kernel'``, ``'abs_delta_kernel'``
        (``cupy.ElementwiseKernel`` objects declared for float64), or
        ``None`` if CuPy cannot be imported.
    """
    cached = Lasso._cupy_fused_kernels
    if cached is not None:
        return cached

    try:
        import cupy as cp
    except ImportError:
        return None

    @cp.fuse()
    def _soft_threshold_fused(x, gamma):
        """Soft thresholding: sign(x) * max(|x| - gamma, 0)."""
        abs_x = abs(x)
        return (x > 0) * (abs_x > gamma) * (abs_x - gamma) - (x < 0) * (abs_x > gamma) * (abs_x - gamma)

    @cp.fuse()
    def _fista_momentum_fused(coef, coef_old, beta):
        """FISTA extrapolation: coef + beta * (coef - coef_old)."""
        return coef + beta * (coef - coef_old)

    @cp.fuse()
    def _kkt_violation_fused(grad, alpha):
        """Per-coordinate KKT violation: max(|grad| - alpha, 0)."""
        abs_grad = abs(grad)
        diff = abs_grad - alpha
        return (diff > 0) * diff

    # Hand-written elementwise kernel for soft thresholding.
    soft_threshold_kernel = cp.ElementwiseKernel(
        'float64 x, float64 gamma',
        'float64 y',
        '''
        double abs_x = abs(x);
        if (abs_x > gamma) {
        y = (x > 0 ? 1.0 : -1.0) * (abs_x - gamma);
        } else {
        y = 0.0;
        }
        ''',
        'lasso_soft_threshold'
    )

    # Hand-written elementwise kernel for |a - b| (convergence check).
    abs_delta_kernel = cp.ElementwiseKernel(
        'float64 a, float64 b',
        'float64 y',
        '''
        double diff = a - b;
        y = (diff > 0 ? diff : -diff);
        ''',
        'lasso_abs_delta'
    )

    kernels = {
        'soft_threshold': _soft_threshold_fused,
        'fista_momentum': _fista_momentum_fused,
        'kkt_violation': _kkt_violation_fused,
        'elementwise_kernel': soft_threshold_kernel,
        'abs_delta_kernel': abs_delta_kernel,
    }
    Lasso._cupy_fused_kernels = kernels
    return kernels
438
+
439
+ def _soft_threshold(self, x, gamma):
440
+ """Soft thresholding operator: S(x, gamma) = sign(x) * max(|x| - gamma, 0)."""
441
+ return np.sign(x) * np.maximum(np.abs(x) - gamma, 0)
442
+
443
+ def _fit_cpu(self, X, y, sample_weight=None):
444
+ """Fit using CPU (coordinate descent or FISTA)."""
445
+ X = np.asarray(X)
446
+ y = np.asarray(y)
447
+
448
+ n_samples, n_features = X.shape
449
+ self._nobs = n_samples
450
+
451
+ if sample_weight is not None:
452
+ sample_weight = np.asarray(sample_weight)
453
+ sqrt_sw = np.sqrt(sample_weight)
454
+ X = X * sqrt_sw[:, np.newaxis]
455
+ y = y * sqrt_sw
456
+
457
+ if self.fit_intercept:
458
+ X_mean = np.mean(X, axis=0)
459
+ y_mean = np.mean(y)
460
+ X_centered = X - X_mean
461
+ y_centered = y - y_mean
462
+ else:
463
+ X_centered = X
464
+ y_mean = 0.0
465
+ y_centered = y
466
+
467
+ if y.ndim == 1:
468
+ y_centered = y_centered.reshape(-1, 1)
469
+
470
+ Xty = X_centered.T @ y_centered.flatten()
471
+ XtX = X_centered.T @ X_centered
472
+
473
+ coef = np.zeros(n_features)
474
+
475
+ if self.cpu_solver in ("fista",):
476
+ # Proximal gradient / FISTA for L1-regularized least squares:
477
+ # minimize (1/(2n)) * ||y - Xw||^2 + alpha * ||w||_1
478
+ # Uses the same stopping criterion as coordinate descent in this codebase:
479
+ # sum(abs(coef - coef_old)) < tol
480
+
481
+ if self.lipschitz_L is not None:
482
+ L = float(self.lipschitz_L)
483
+ else:
484
+ L_frob = float(np.sum(X_centered**2) / n_samples)
485
+ try:
486
+ eigvals = np.linalg.eigvalsh(XtX)
487
+ L = float(eigvals[-1] / n_samples)
488
+ except Exception:
489
+ L = L_frob
490
+
491
+ if L <= 0:
492
+ coef = np.zeros(n_features)
493
+ self.n_iter_ = 0
494
+ else:
495
+ step = 1.0 / L
496
+ thresh = self.alpha * step
497
+
498
+ # FISTA variables
499
+ y_k = coef.copy()
500
+ t_k = 1.0
501
+
502
+ for iteration in range(self.max_iter):
503
+ coef_old = coef.copy()
504
+
505
+ # grad = (XtX @ y_k - Xty) / n
506
+ grad = (XtX @ y_k - Xty) / n_samples
507
+
508
+ coef = self._soft_threshold(y_k - step * grad, thresh)
509
+
510
+ # Momentum update
511
+ t_new = (1.0 + np.sqrt(1.0 + 4.0 * (t_k**2))) / 2.0
512
+ beta = (t_k - 1.0) / t_new
513
+ y_k = coef + beta * (coef - coef_old)
514
+ t_k = t_new
515
+
516
+ if self.stopping == "kkt":
517
+ # KKT violation for Lasso:
518
+ # grad_sse = (XtX @ w - Xty) / n
519
+ # optimality: |grad_sse_j| <= alpha when w_j == 0
520
+ # violation measure: max_j max(|grad_sse_j| - alpha, 0)
521
+ grad_sse = (XtX @ coef - Xty) / n_samples
522
+ violation = np.max(np.maximum(np.abs(grad_sse) - self.alpha, 0.0))
523
+ if violation < self.tol:
524
+ self.n_iter_ = iteration + 1
525
+ break
526
+ else:
527
+ # Legacy stopping: coefficient delta
528
+ if np.sum(np.abs(coef - coef_old)) < self.tol:
529
+ self.n_iter_ = iteration + 1
530
+ break
531
+ else:
532
+ self.n_iter_ = self.max_iter
533
+
534
+ else:
535
+ # Coordinate descent (legacy CPU path)
536
+ # Precompute squared norms for each feature
537
+ X_sq_norms = np.diag(XtX)
538
+
539
+ for iteration in range(self.max_iter):
540
+ coef_old = coef.copy()
541
+
542
+ for j in range(n_features):
543
+ # Compute partial residual
544
+ rho_j = Xty[j] - np.dot(XtX[j, :], coef) + XtX[j, j] * coef[j]
545
+
546
+ # Update coefficient with soft thresholding
547
+ if X_sq_norms[j] > 1e-10:
548
+ coef[j] = self._soft_threshold(rho_j, self.alpha * n_samples) / X_sq_norms[j]
549
+ else:
550
+ coef[j] = 0.0
551
+
552
+ # Check convergence
553
+ if self.stopping == "kkt":
554
+ grad_sse = (XtX @ coef - Xty) / n_samples
555
+ violation = np.max(np.maximum(np.abs(grad_sse) - self.alpha, 0.0))
556
+ if violation < self.tol:
557
+ self.n_iter_ = iteration + 1
558
+ break
559
+ else:
560
+ if np.sum(np.abs(coef - coef_old)) < self.tol:
561
+ self.n_iter_ = iteration + 1
562
+ break
563
+ else:
564
+ self.n_iter_ = self.max_iter
565
+
566
+ # Compute intercept
567
+ if self.fit_intercept:
568
+ self.intercept_ = float(y_mean - X_mean @ coef)
569
+ self.coef_ = coef
570
+ self._params = np.concatenate([[self.intercept_], self.coef_])
571
+ else:
572
+ self.intercept_ = 0.0
573
+ self.coef_ = coef
574
+ self._params = coef.copy()
575
+ self._df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))
576
+ if self.compute_inference:
577
+ if self.fit_intercept:
578
+ self._X_design = np.column_stack(
579
+ [np.ones(n_samples, dtype=X.dtype), X]
580
+ )
581
+ else:
582
+ self._X_design = X.copy()
583
+
584
+ y_pred = self._X_design @ self._params
585
+ self._resid = self._y - y_pred
586
+
587
+ if self._df_resid > 0:
588
+ self._scale = np.sum(self._resid ** 2) / self._df_resid
589
+ else:
590
+ self._scale = np.nan
591
+ else:
592
+ self._X_design = None
593
+ self._resid = None
594
+ self._scale = np.nan
595
+
596
def _soft_threshold_cupy(self, x, gamma):
    """Soft-threshold a CuPy array: sign(x) * max(|x| - gamma, 0).

    When the cached kernels from ``_get_cupy_fused_kernels`` are available,
    the dedicated ``'elementwise_kernel'`` is used; otherwise the result is
    computed with plain CuPy array operations.
    """
    import cupy as cp

    kernels = self._get_cupy_fused_kernels()
    if kernels is None:
        # No fused kernels: standard elementwise formulation.
        return cp.sign(x) * cp.maximum(cp.abs(x) - gamma, 0)
    return kernels['elementwise_kernel'](x, gamma)
612
+
613
+ def _cleanup_cuda_memory(self):
614
+ """
615
+ Best-effort CUDA memory pool cleanup.
616
+
617
+ CuPy caches freed blocks in its memory pool for speed. Enable
618
+ `gpu_memory_cleanup=True` to return cached blocks after fit when
619
+ VRAM pressure is more important than repeated-fit throughput.
620
+ """
621
+ if not self.gpu_memory_cleanup:
622
+ return
623
+ try:
624
+ import cupy as cp
625
+ cp.get_default_memory_pool().free_all_blocks()
626
+ cp.get_default_pinned_memory_pool().free_all_blocks()
627
+ except Exception:
628
+ pass
629
+
630
def _fit_gpu(self, X, y, sample_weight=None):
    """Fit the Lasso on the GPU with CuPy (FISTA, or ADMM by delegation).

    The FISTA path works on the Gram form of the problem: the smooth-part
    gradient is ``(XtX @ w - Xty) / n`` and each iteration applies an L1
    soft-threshold with threshold ``alpha / L``.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Feature matrix without an intercept column; converted with
        ``cp.asarray``.
    y : array-like
        Response; flattened to a vector before any weighting.
    sample_weight : array-like, optional
        When given, rows of ``X`` and entries of ``y`` are multiplied by
        ``sqrt(sample_weight)``.  Presumably of length n_samples.

    Returns
    -------
    None for the FISTA path; whatever ``_fit_gpu_admm`` returns when
    ``self.solver == "admm"``.

    Raises
    ------
    ValueError
        If ``self.solver`` is neither ``'fista'`` nor ``'admm'``.
    """
    import cupy as cp

    if self.solver not in ("fista", "admm"):
        raise ValueError("solver must be one of: 'fista', 'admm'")

    if self.solver == "admm":
        return self._fit_gpu_admm(X, y, sample_weight=sample_weight)

    # Default: FISTA
    n_samples, n_features = X.shape
    self._nobs = n_samples

    X = cp.asarray(X)
    # Flatten y *before* weighting: an (n, 1) column multiplied by an (n,)
    # weight vector would otherwise broadcast to an (n, n) matrix.
    y = cp.asarray(y).reshape(-1)

    if sample_weight is not None:
        # NOTE(review): centering below uses the unweighted mean of the
        # sqrt-weighted data and the design matrix uses a column of ones;
        # confirm this matches the intended weighted-intercept convention.
        sample_weight = cp.asarray(sample_weight)
        sqrt_sw = cp.sqrt(sample_weight)
        X = X * sqrt_sw[:, cp.newaxis]
        y = y * sqrt_sw

    # Center X/y when fitting an intercept so the intercept drops out of
    # the penalised problem and can be recovered afterwards.
    if self.fit_intercept:
        X_mean = cp.mean(X, axis=0)
        y_mean = cp.mean(y)
        X_centered = X - X_mean
        y_centered = y - y_mean
    else:
        X_centered = X
        y_mean = cp.array(0.0, dtype=X.dtype)
        y_centered = y

    # Precompute XtX / Xty for the FISTA gradient.
    XtX = X_centered.T @ X_centered
    Xty = X_centered.T @ y_centered

    # Lipschitz constant of the gradient: L = lambda_max(XtX) / n.
    # A user-supplied value is trusted as-is.
    if self.lipschitz_L is not None:
        L = cp.array(float(self.lipschitz_L), dtype=X.dtype)
    else:
        try:
            L = cp.linalg.eigvalsh(XtX)[-1] / n_samples
        except Exception:
            # Best-effort fallback: trace(XtX) / n bounds lambda_max / n
            # from above, so the resulting step is still safe.
            L = cp.sum(X_centered ** 2) / n_samples

    if L <= 0:
        # Degenerate case: return all-zero coefficients.
        coef = cp.zeros(n_features, dtype=X.dtype)
        self.n_iter_ = 0
    else:
        step = 1.0 / L
        thresh = self.alpha * step

        coef = cp.zeros(n_features, dtype=X.dtype)  # current iterate w_k
        y_k = coef.copy()                           # extrapolated point
        t_k = cp.array(1.0, dtype=X.dtype)          # momentum scalar

        # Optional fused kernels; None means "use plain CuPy expressions".
        fused = self._get_cupy_fused_kernels()

        for iteration in range(self.max_iter):
            coef_old = coef

            # Gradient at the extrapolated point, then L1 prox step.
            grad = (XtX @ y_k - Xty) / n_samples
            coef = self._soft_threshold_cupy(y_k - step * grad, thresh)

            # Nesterov momentum update.
            t_new = (1 + cp.sqrt(1 + 4 * (t_k ** 2))) / 2
            beta = (t_k - 1) / t_new
            if fused is not None:
                y_k = fused['fista_momentum'](coef, coef_old, beta)
            else:
                y_k = coef + beta * (coef - coef_old)
            t_k = t_new

            if self.stopping == "kkt":
                # KKT check: |grad_j| must not exceed alpha.
                grad_sse = (XtX @ coef - Xty) / n_samples
                if fused is not None:
                    violation = cp.max(fused['kkt_violation'](grad_sse, self.alpha))
                else:
                    violation = cp.max(cp.maximum(cp.abs(grad_sse) - self.alpha, 0.0))
                if violation < self.tol:
                    self.n_iter_ = iteration + 1
                    break
            else:
                # Legacy stopping: L1 change of the coefficients (fast, but
                # does not guarantee objective optimality).
                if fused is not None and 'abs_delta_kernel' in fused:
                    delta = cp.sum(fused['abs_delta_kernel'](coef, coef_old))
                else:
                    delta = cp.sum(cp.abs(coef - coef_old))
                if delta < self.tol:
                    self.n_iter_ = iteration + 1
                    break
        else:
            # Loop exhausted without meeting the stopping rule.
            self.n_iter_ = self.max_iter

    # Solver temporaries are no longer needed; drop the references now so
    # the inference stage below has more free device memory.
    XtX = Xty = X_centered = y_centered = None

    # Recover the intercept from the centering means.
    if self.fit_intercept:
        intercept_gpu = y_mean - X_mean @ coef
        coef_full = cp.concatenate([intercept_gpu.reshape(1), coef])
    else:
        coef_full = coef

    # Coefficients are always transferred to the host.
    coef_full_np = coef_full.get()

    if self.fit_intercept:
        self.intercept_ = float(coef_full_np[0])
        self.coef_ = coef_full_np[1:]
    else:
        self.intercept_ = 0.0
        self.coef_ = coef_full_np
    self._params = coef_full_np

    df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))
    self._df_resid = df_resid

    X_design = y_pred = resid = None
    if self.compute_inference:
        # Only build the design matrix when residuals/inference are needed.
        if self.fit_intercept:
            X_design = cp.concatenate(
                [cp.ones((n_samples, 1), dtype=X.dtype), X], axis=1
            )
        else:
            X_design = X

        y_pred = X_design @ coef_full
        resid = y - y_pred

        if df_resid > 0:
            scale = cp.sum(resid ** 2) / df_resid
            self._scale = float(scale.get()) if not cp.isnan(scale) else np.nan
        else:
            self._scale = np.nan
            scale = cp.nan

        if self.inference_method in ("gpu_ols_inference", "debiased"):
            if self.inference_method == "gpu_ols_inference":
                # OLS-style inference fully on the GPU; only small vectors
                # are transferred back.
                gram_design = X_design.T @ X_design
                try:
                    XtX_inv = cp.linalg.inv(gram_design)
                except Exception:
                    # Singular design: fall back to the pseudo-inverse.
                    XtX_inv = cp.linalg.pinv(gram_design)

                bse_gpu = cp.sqrt(scale * cp.diag(XtX_inv))

                params_gpu = coef_full  # includes intercept when fitted
                # Tiny constant avoids division by an exactly-zero SE.
                tvalues_gpu = params_gpu / (bse_gpu + 1e-30)
                # Two-tailed p-values should already lie in [0, 1]; the
                # clamp guards against tiny floating-point overshoot.
                pvalues_gpu = cp.minimum(1.0, 2.0 * t.sf(cp.abs(tvalues_gpu), df=df_resid))

                alpha = 0.05  # two-tailed, i.e. 95% intervals
                t_crit_gpu = t.ppf(1.0 - alpha / 2.0, df=df_resid)
                margin_gpu = t_crit_gpu * bse_gpu
                conf_int_gpu = cp.stack([params_gpu - margin_gpu, params_gpu + margin_gpu], axis=1)

                self._bse = cp.asnumpy(bse_gpu)
                self._tvalues = cp.asnumpy(tvalues_gpu)
                self._pvalues = cp.asnumpy(pvalues_gpu)
                self._conf_int = cp.asnumpy(conf_int_gpu)
            else:
                self._compute_inference_debiased_gpu(X, y, coef)

            # R^2 computed on the GPU so residuals never leave the device.
            y_mean_gpu = cp.mean(y)
            ss_tot = cp.sum((y - y_mean_gpu) ** 2)
            ss_res = cp.sum(resid ** 2)
            self._rsquared_gpu = float(cp.asnumpy(1 - ss_res / ss_tot)) if ss_tot > 0 else 0.0

            self._resid = None
            self._X_design = None
        else:
            # Default: transfer residuals and design to the host.
            self._resid = resid.get()
            self._X_design = X_design.get()
    else:
        # Strict GPU mode: avoid large residual/host design transfers.
        self._scale = np.nan
        self._resid = None
        self._X_design = None
        # No residuals available, so R^2 is left undefined.
        self._rsquared_gpu = None

    # Drop references to large temporaries before the optional pool cleanup
    # so their blocks can actually be returned.
    X_design = resid = y_pred = coef_full = None
    self._cleanup_cuda_memory()
880
+
881
+ def _cleanup_torch_memory(self):
882
+ """Best-effort Torch CUDA memory cleanup."""
883
+ if not self.gpu_memory_cleanup:
884
+ return
885
+ try:
886
+ import torch
887
+ if torch.cuda.is_available():
888
+ torch.cuda.empty_cache()
889
+ torch.cuda.synchronize()
890
+ except Exception:
891
+ pass
892
+
893
+ def _matrix_fingerprint_torch(self, X: "torch.Tensor") -> str:
894
+ """Generate a fingerprint key for caching debiased M matrix (Torch version)."""
895
+ import torch
896
+ n, p = X.shape
897
+ r = min(24, n)
898
+ c = min(24, p)
899
+ sample = X[:r, :c].contiguous()
900
+ h = hashlib.sha1()
901
+ h.update(str((n, p, str(X.dtype))).encode("utf-8"))
902
+ h.update(sample.cpu().numpy().tobytes())
903
+ return h.hexdigest()
904
+
905
+ def _solve_lasso_path_torch_fista_multi_fold_from_gram(
906
+ self,
907
+ XtX_batch,
908
+ Xty_batch,
909
+ *,
910
+ n_samples_vec,
911
+ alphas_desc,
912
+ max_iter,
913
+ tol,
914
+ stopping,
915
+ lipschitz_L=None,
916
+ check_every=8,
917
+ ):
918
+ """Solve descending-alpha Lasso paths for all folds together on Torch GPU."""
919
+ import torch
920
+
921
+ n_folds = int(XtX_batch.shape[0])
922
+ n_features = int(XtX_batch.shape[1])
923
+ n_alphas = int(alphas_desc.shape[0])
924
+ dtype = XtX_batch.dtype
925
+ device = XtX_batch.device
926
+
927
+ coefs = torch.zeros((n_folds, n_features, n_alphas), dtype=dtype, device=device)
928
+ yk = coefs.clone()
929
+ tk = torch.ones((n_folds, n_alphas), dtype=dtype, device=device)
930
+ n_iters = torch.zeros((n_folds, n_alphas), dtype=torch.int32, device=device)
931
+
932
+ n_vec = torch.as_tensor(n_samples_vec, dtype=dtype, device=device).reshape(-1)
933
+ if n_vec.size != n_folds:
934
+ raise ValueError("n_samples_vec must have one entry per fold")
935
+
936
+ if lipschitz_L is not None:
937
+ L = torch.full((n_folds,), float(lipschitz_L), dtype=dtype, device=device)
938
+ else:
939
+ try:
940
+ eigvals = torch.linalg.eigvalsh(XtX_batch)
941
+ L = eigvals[:, -1] / n_vec
942
+ except Exception:
943
+ row_sum_bound = torch.max(torch.sum(torch.abs(XtX_batch), dim=2), dim=1)[0] / n_vec
944
+ L = torch.maximum(row_sum_bound, torch.tensor(1e-12, dtype=dtype, device=device))
945
+
946
+ step = 1.0 / L.reshape(n_folds, 1, 1)
947
+ alpha_gpu = torch.as_tensor(np.asarray(alphas_desc, dtype=np.float64), dtype=dtype, device=device).reshape(1, 1, n_alphas)
948
+ thresholds = alpha_gpu * step
949
+
950
+ Xty_expanded = Xty_batch.reshape(n_folds, n_features, 1)
951
+ n_vec_expanded = n_vec.reshape(n_folds, 1, 1)
952
+ stopping_name = str(stopping).lower()
953
+ check_every = max(1, int(check_every))
954
+
955
+ active_gpu = torch.ones((n_folds, n_alphas), dtype=torch.bool, device=device)
956
+ active_count = int(n_folds * n_alphas)
957
+
958
+ for iteration in range(int(max_iter)):
959
+ if active_count == 0:
960
+ break
961
+
962
+ active_expanded = active_gpu[:, None, :]
963
+
964
+ coef_old = coefs.clone()
965
+ grad = (torch.matmul(XtX_batch, yk) - Xty_expanded) / n_vec_expanded
966
+ coef_candidate = torch.sign(yk - step * grad) * torch.maximum(torch.abs(yk - step * grad) - thresholds, torch.tensor(0.0, dtype=dtype, device=device))
967
+ coefs = torch.where(active_expanded, coef_candidate, coefs)
968
+
969
+ t_old = tk
970
+ t_new = (1.0 + torch.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
971
+ beta = (t_old - 1.0) / t_new
972
+ y_candidate = coefs + beta[:, None, :] * (coefs - coef_old)
973
+ yk = torch.where(active_expanded, y_candidate, yk)
974
+ tk = torch.where(active_gpu, t_new, tk)
975
+
976
+ active_ratio = float(active_count) / float(max(1, n_folds * n_alphas))
977
+ check_every_eff = max(check_every, 1)
978
+ should_check = ((iteration + 1) % check_every_eff == 0) or (iteration + 1 == int(max_iter))
979
+ if not should_check:
980
+ continue
981
+
982
+ if stopping_name == "kkt":
983
+ grad_sse = (torch.matmul(XtX_batch, coefs) - Xty_expanded) / n_vec_expanded
984
+ violation = torch.max(torch.maximum(torch.abs(grad_sse) - alpha_gpu, torch.tensor(0.0, dtype=dtype, device=device)), dim=1)[0]
985
+ converged_local_gpu = violation < float(tol)
986
+ else:
987
+ delta = torch.sum(torch.abs(coefs - coef_old), dim=1)
988
+ converged_local_gpu = delta < float(tol)
989
+
990
+ newly_done_gpu = active_gpu & converged_local_gpu
991
+ done_count = int(torch.count_nonzero(newly_done_gpu).item())
992
+ if done_count == 0:
993
+ continue
994
+
995
+ n_iters[newly_done_gpu] = int(iteration) + 1
996
+ yk = torch.where(newly_done_gpu[:, None, :], coefs, yk)
997
+ active_gpu = active_gpu & (~converged_local_gpu)
998
+ active_count -= done_count
999
+
1000
+ return coefs.transpose(1, 2), n_iters.cpu().numpy()
1001
+
1002
def _compute_inference_debiased_torch(self, X_torch, y_torch, coef_torch):
    """Torch GPU path for debiased Lasso inference.

    Estimates an approximate inverse covariance ``M`` by node-wise Lasso,
    forms the debiased estimate ``coef + M @ X' @ resid / n`` and derives
    standard errors, z-statistics, p-values and 95% intervals from it.
    Results are stored on ``self`` (``_bse``, ``_tvalues``, ``_pvalues``,
    ``_conf_int``, ``_params``, ``_debiased_M_cpu``); nothing is returned.

    Parameters
    ----------
    X_torch : torch.Tensor, shape (n, p)
        Raw feature matrix on Torch GPU (no intercept column).
    y_torch : torch.Tensor, shape (n,)
        Response on Torch GPU.
    coef_torch : torch.Tensor, shape (p,)
        Lasso coefficients on Torch GPU (no intercept).
    """
    import torch
    from statgpu.inference._distributions_backend import norm

    n, p = X_torch.shape
    dtype = torch.float64
    device = X_torch.device

    # All computations run in float64; .to() is a no-op when already so.
    X_torch = X_torch.to(dtype)
    y_torch = y_torch.to(dtype)
    coef_torch = coef_torch.to(dtype)

    zero_t = torch.tensor(0.0, dtype=dtype, device=device)
    one_t = torch.tensor(1.0, dtype=dtype, device=device)

    # Gram matrix is computed once and shared by Sigma_hat and node-wise Lasso.
    gram = X_torch.T @ X_torch
    Sigma_hat = gram / n

    # Lasso residuals (intercept removed via the means when one was fitted).
    resid_lasso = y_torch - X_torch @ coef_torch
    if self.fit_intercept:
        resid_lasso = resid_lasso - torch.mean(y_torch) + torch.mean(X_torch, dim=0) @ coef_torch

    # Noise variance: RSS / max(1, n - number of non-zero coefficients).
    s_hat = torch.sum(torch.abs(coef_torch) > 0).to(dtype)
    denom = torch.maximum(one_t, torch.tensor(float(n), dtype=dtype, device=device) - s_hat)
    sigma2 = torch.sum(resid_lasso ** 2) / denom

    # Node-wise Lasso penalty.
    lam_nw = float(np.sqrt(2.0 * np.log(max(p, 2)) / n))
    alpha_nw = np.asarray([lam_nw], dtype=np.float64)
    tiny = 1e-30

    # M is cached, keyed on shape/dtype/penalty and a small corner of X.
    # NOTE(review): the key only samples a 24x24 block, so distinct matrices
    # sharing that corner would presumably share a cache entry -- confirm.
    X_sample = X_torch[: min(24, n), : min(24, p)].cpu().numpy()
    m_cache_key = _debiased_m_key_from_sample(
        n=n,
        p=p,
        dtype_name=str(dtype),
        sample_block=X_sample,
        lam_nw=lam_nw,
        tol=float(self.tol),
    )
    M_cached = _debiased_m_cache_get(m_cache_key)

    if M_cached is not None:
        M = torch.from_numpy(M_cached).to(dtype).to(device)
    else:
        M = torch.zeros((p, p), dtype=dtype, device=device)
        Sigma_diag = torch.diag(Sigma_hat)

        # Choose how many node-wise problems to solve per batch from the
        # free device memory; fall back to a fixed size on any failure.
        try:
            if torch.cuda.is_available():
                free_mem = torch.cuda.mem_get_info(device)[0]
                bytes_per_fold = max(8, (p - 1) * (p - 1) * 8 * 2)
                chunk_size = int(max(4, min(64, free_mem // max(bytes_per_fold, 1))))
            else:
                chunk_size = 16
        except Exception:
            chunk_size = 16
        chunk_size = max(4, min(int(p), chunk_size))

        # int64 indices: accepted for advanced indexing by every torch version.
        base = torch.arange(p - 1, dtype=torch.long, device=device).reshape(1, -1)

        for j0 in range(0, p, chunk_size):
            j1 = min(p, j0 + chunk_size)
            bsz = j1 - j0
            j_batch = torch.arange(j0, j1, dtype=torch.long, device=device)

            # Row b lists every column index except j_batch[b].
            cols_batch = base + (base >= j_batch.reshape(-1, 1)).to(torch.long)

            if p > 1:
                # Gather the per-target Gram blocks and right-hand sides.
                XtX_batch = gram[
                    cols_batch[:, :, None],
                    cols_batch[:, None, :],
                ]
                Xty_batch = gram[cols_batch, j_batch.reshape(-1, 1)].reshape(bsz, p - 1)

                coefs_batch_desc, _ = self._solve_lasso_path_torch_fista_multi_fold_from_gram(
                    XtX_batch,
                    Xty_batch,
                    n_samples_vec=np.full((bsz,), float(n), dtype=np.float64),
                    alphas_desc=alpha_nw,
                    max_iter=500,
                    tol=1e-5,
                    stopping="coef_delta",
                    lipschitz_L=None,
                    check_every=8,
                )
                # The solver returns a tensor on the compute device; keep
                # it there instead of round-tripping through NumPy (which
                # fails for CUDA tensors).
                gamma_batch = torch.as_tensor(coefs_batch_desc[:, 0, :]).to(device=device, dtype=dtype)
                del XtX_batch, Xty_batch, coefs_batch_desc
            else:
                # Single feature: there are no other columns to regress on.
                gamma_batch = torch.zeros((bsz, 0), dtype=dtype, device=device)

            # C_j = Sigma_jj - Sigma_{j,-j} gamma_j
            sigma_j_cols = Sigma_hat[j_batch[:, None], cols_batch]
            C_batch = Sigma_diag[j_batch] - torch.sum(sigma_j_cols * gamma_batch, dim=1)

            # Guard against division by a (near-)zero C_j: such rows get a
            # unit diagonal and zero off-diagonals.
            small_c = torch.abs(C_batch) < tiny
            inv_c = torch.where(small_c, zero_t, one_t / C_batch)
            M[j_batch, j_batch] = torch.where(small_c, one_t, inv_c)
            M[j_batch[:, None], cols_batch] = -gamma_batch * inv_c.reshape(-1, 1)

            del gamma_batch, sigma_j_cols

        _debiased_m_cache_put(m_cache_key, M.cpu().numpy())

    # Residual of the fitted model, using the stored intercept when present.
    if self.fit_intercept:
        y_pred = X_torch @ coef_torch + torch.tensor(self.intercept_, dtype=dtype, device=device)
    else:
        y_pred = X_torch @ coef_torch
    resid_full = y_torch - y_pred

    # Debiased estimate: theta_db = coef + M @ X' @ resid / n
    theta_db = coef_torch + (M @ X_torch.T @ resid_full) / n

    # Variance estimation: V = M @ Sigma_hat @ M'
    V = M @ Sigma_hat @ M.T
    se = torch.sqrt(sigma2 * torch.diag(V) / n)

    # z-statistics and two-sided p-values (clamped at 1 defensively).
    z_stats = theta_db / (se + 1e-30)
    pvalues = torch.minimum(one_t, 2.0 * norm.sf(torch.abs(z_stats)))

    # 95% confidence intervals.
    alpha_ci = 0.05
    z_crit = norm.ppf(1.0 - alpha_ci / 2.0)
    ci = torch.stack([theta_db - z_crit * se, theta_db + z_crit * se], dim=1)

    if self.fit_intercept:
        # The intercept is not debiased; its SE comes from the OLS formula
        # on the design matrix with a leading column of ones.
        X_full = torch.cat([torch.ones((n, 1), dtype=dtype, device=device), X_torch], dim=1)
        XtX_full = X_full.T @ X_full
        try:
            XtX_inv = torch.linalg.inv(XtX_full)
        except Exception:
            # Singular design: fall back to the pseudo-inverse.
            XtX_inv = torch.linalg.pinv(XtX_full)
        se_intercept = torch.sqrt(sigma2 * XtX_inv[0, 0])
        intercept_torch = torch.tensor(self.intercept_, dtype=dtype, device=device)
        z_intercept = intercept_torch / (se_intercept + 1e-30)
        p_intercept = torch.minimum(one_t, 2.0 * norm.sf(torch.abs(z_intercept).reshape(1)))
        ci_intercept = torch.stack([
            intercept_torch - z_crit * se_intercept,
            intercept_torch + z_crit * se_intercept,
        ]).reshape(1, 2)

        bse_torch = torch.cat([se_intercept.reshape(1), se])
        tvalues_torch = torch.cat([z_intercept.reshape(1), z_stats])
        pvalues_torch = torch.cat([p_intercept.reshape(1), pvalues])
        conf_int_torch = torch.cat([ci_intercept, ci], dim=0)
        params_torch = torch.cat([intercept_torch.reshape(1), theta_db])
    else:
        bse_torch = se
        tvalues_torch = z_stats
        pvalues_torch = pvalues
        conf_int_torch = ci
        params_torch = theta_db

    # Transfer the small result vectors to the host.
    self._bse = bse_torch.cpu().numpy()
    self._tvalues = tvalues_torch.cpu().numpy()
    self._pvalues = pvalues_torch.cpu().numpy()
    self._conf_int = conf_int_torch.cpu().numpy()
    self._params = params_torch.cpu().numpy()

    # Keep M for later simultaneous inference.
    self._debiased_M_cpu = M.cpu().numpy()

    # Simultaneous inference (max-|Z| bootstrap).
    if self.enable_simultaneous_inference:
        self._compute_simultaneous_inference_torch(
            params_torch, bse_torch, se, M, X_torch, resid_full, n
        )
1196
+
1197
+ def _compute_simultaneous_inference_torch(
1198
+ self, params_torch, bse_torch, se_feat_torch, M_torch, X_torch, resid_full_torch, n
1199
+ ):
1200
+ """Torch GPU implementation of simultaneous inference via max-|Z| bootstrap."""
1201
+ import torch
1202
+
1203
+ # Get target indices
1204
+ param_target_idx_np = self._get_simultaneous_target_indices(int(params_torch.shape[0]))
1205
+ param_target_idx_torch = torch.as_tensor(param_target_idx_np, dtype=torch.int32, device=params_torch.device)
1206
+
1207
+ if param_target_idx_torch.size == 0:
1208
+ raise RuntimeError("No coefficients selected for simultaneous inference target set.")
1209
+
1210
+ feature_offset = 1 if self.fit_intercept else 0
1211
+ feature_target_torch = param_target_idx_torch - feature_offset
1212
+ feature_target_torch = feature_target_torch[feature_target_torch >= 0]
1213
+
1214
+ if feature_target_torch.size == 0:
1215
+ raise RuntimeError("No feature coefficients selected for simultaneous inference target set.")
1216
+
1217
+ se_target_torch = torch.index_select(se_feat_torch, 0, feature_target_torch)
1218
+ M_target = torch.index_select(M_torch, 0, feature_target_torch)
1219
+
1220
+ B = int(self.simultaneous_n_bootstrap)
1221
+ if self.simultaneous_random_state is not None:
1222
+ torch.manual_seed(self.simultaneous_random_state)
1223
+
1224
+ # Bootstrap in chunks to manage memory
1225
+ try:
1226
+ # Try one-shot computation
1227
+ xi = torch.randn((B, n), dtype=torch.float64, device=X_torch.device)
1228
+ weighted = xi * resid_full_torch.reshape(1, -1)
1229
+ score_target = (weighted @ X_torch) @ M_target.T / float(max(n, 1))
1230
+ z_star_target = score_target / (se_target_torch.reshape(1, -1) + 1e-30)
1231
+ max_stats_torch = torch.max(torch.abs(z_star_target), dim=1)[0]
1232
+ except Exception:
1233
+ # Fallback to chunked computation
1234
+ max_stats_torch = torch.empty((B,), dtype=torch.float64, device=X_torch.device)
1235
+ chunk = min(B, 64)
1236
+ filled = 0
1237
+ while filled < B:
1238
+ bsz = min(chunk, B - filled)
1239
+ xi = torch.randn((bsz, n), dtype=torch.float64, device=X_torch.device)
1240
+ weighted = xi * resid_full_torch.reshape(1, -1)
1241
+ score_target = (weighted @ X_torch) @ M_target.T / float(max(n, 1))
1242
+ z_star_target = score_target / (se_target_torch.reshape(1, -1) + 1e-30)
1243
+ max_stats_torch[filled : filled + bsz] = torch.max(torch.abs(z_star_target), dim=1)[0]
1244
+ filled += bsz
1245
+
1246
+ # Compute critical value
1247
+ critical_torch = torch.quantile(max_stats_torch, 1.0 - float(self.simultaneous_alpha))
1248
+
1249
+ # Build simultaneous confidence intervals
1250
+ conf_sim_torch = conf_int_torch.clone()
1251
+ lower_torch = torch.index_select(params_torch, 0, param_target_idx_torch) - critical_torch * torch.index_select(bse_torch, 0, param_target_idx_torch)
1252
+ upper_torch = torch.index_select(params_torch, 0, param_target_idx_torch) + critical_torch * torch.index_select(bse_torch, 0, param_target_idx_torch)
1253
+ conf_sim_torch[param_target_idx_torch, 0] = lower_torch
1254
+ conf_sim_torch[param_target_idx_torch, 1] = upper_torch
1255
+
1256
+ # Store results
1257
+ target_mask = np.zeros(int(params_torch.shape[0]), dtype=bool)
1258
+ target_mask[param_target_idx_np] = True
1259
+ self._conf_int_simultaneous = conf_sim_torch.cpu().numpy()
1260
+ self._simultaneous_enabled = True
1261
+ self._simultaneous_method = self.simultaneous_method
1262
+ self._simultaneous_alpha = float(self.simultaneous_alpha)
1263
+ self._simultaneous_n_bootstrap = B
1264
+ self._simultaneous_critical_value = float(critical_torch.cpu().numpy())
1265
+ self._simultaneous_target_mask = target_mask
1266
+
1267
+ def _soft_threshold_torch(self, x, gamma):
1268
+ """Soft thresholding operator for Torch tensors."""
1269
+ import torch
1270
+ return torch.sign(x) * torch.maximum(torch.abs(x) - gamma, torch.tensor(0.0, dtype=x.dtype, device=x.device))
1271
+
1272
+ def _fit_torch(self, X, y, sample_weight=None):
1273
+ """Fit using Torch GPU with FISTA solver."""
1274
+ import torch
1275
+ from statgpu.backends._gpu_inference_torch import compute_r2_torch
1276
+
1277
+ if self.solver not in ("fista", "admm"):
1278
+ raise ValueError("Torch backend currently only supports 'fista' solver")
1279
+
1280
+ # For now, only FISTA is implemented for Torch backend
1281
+ if self.solver == "admm":
1282
+ raise NotImplementedError("ADMM solver not yet implemented for Torch backend")
1283
+
1284
+ n_samples, n_features = X.shape
1285
+ self._nobs = n_samples
1286
+
1287
+ # Ensure Torch tensors on GPU
1288
+ if not isinstance(X, torch.Tensor):
1289
+ X = torch.from_numpy(X).to('cuda')
1290
+ if not isinstance(y, torch.Tensor):
1291
+ y = torch.from_numpy(y).to('cuda')
1292
+ if y.dtype != torch.float64:
1293
+ y = y.to(torch.float64)
1294
+ if X.dtype != torch.float64:
1295
+ X = X.to(torch.float64)
1296
+
1297
+ if sample_weight is not None:
1298
+ if not isinstance(sample_weight, torch.Tensor):
1299
+ sample_weight = torch.from_numpy(sample_weight).to('cuda')
1300
+ sqrt_sw = torch.sqrt(sample_weight)
1301
+ X = X * sqrt_sw[:, None]
1302
+ y = y * sqrt_sw
1303
+
1304
+ # Ensure vector y on GPU
1305
+ y = y.reshape(-1)
1306
+
1307
+ # Center for intercept
1308
+ if self.fit_intercept:
1309
+ X_mean = torch.mean(X, dim=0)
1310
+ y_mean = torch.mean(y)
1311
+ X_centered = X - X_mean
1312
+ y_centered = y - y_mean
1313
+ else:
1314
+ X_centered = X
1315
+ y_mean = torch.tensor(0.0, dtype=X.dtype, device=X.device)
1316
+ y_centered = y
1317
+
1318
+ # Precompute XtX / Xty for FISTA gradient
1319
+ XtX = X_centered.T @ X_centered
1320
+ Xty = X_centered.T @ y_centered
1321
+
1322
+ # Lipschitz constant L
1323
+ if self.lipschitz_L is not None:
1324
+ L = torch.tensor(float(self.lipschitz_L), dtype=X.dtype, device=X.device)
1325
+ else:
1326
+ L_frob = torch.sum(X_centered ** 2) / n_samples
1327
+ try:
1328
+ eigvals = torch.linalg.eigvalsh(XtX)
1329
+ L = eigvals[-1] / n_samples
1330
+ except Exception:
1331
+ L = L_frob
1332
+
1333
+ if L <= 0:
1334
+ coef = torch.zeros(n_features, dtype=X.dtype, device=X.device)
1335
+ self.n_iter_ = 0
1336
+ else:
1337
+ step = 1.0 / L
1338
+ thresh = self.alpha * step
1339
+
1340
+ # FISTA variables
1341
+ coef = torch.zeros(n_features, dtype=X.dtype, device=X.device)
1342
+ y_k = coef.clone()
1343
+ t_k = torch.tensor(1.0, dtype=X.dtype, device=X.device)
1344
+
1345
+ for iteration in range(self.max_iter):
1346
+ coef_old = coef.clone()
1347
+
1348
+ # Gradient at y_k
1349
+ grad = (XtX @ y_k - Xty) / n_samples
1350
+
1351
+ # Prox step for L1
1352
+ coef = self._soft_threshold_torch(y_k - step * grad, thresh)
1353
+
1354
+ # Momentum update
1355
+ t_new = (1.0 + torch.sqrt(1.0 + 4.0 * (t_k ** 2))) / 2.0
1356
+ beta = (t_k - 1.0) / t_new
1357
+ y_k = coef + beta * (coef - coef_old)
1358
+ t_k = t_new
1359
+
1360
+ # Convergence test
1361
+ if self.stopping == "kkt":
1362
+ grad_sse = (XtX @ coef - Xty) / n_samples
1363
+ violation = torch.max(torch.maximum(torch.abs(grad_sse) - self.alpha, torch.tensor(0.0, dtype=X.dtype, device=X.device)))
1364
+ if violation < self.tol:
1365
+ self.n_iter_ = iteration + 1
1366
+ break
1367
+ else:
1368
+ if torch.sum(torch.abs(coef - coef_old)) < self.tol:
1369
+ self.n_iter_ = iteration + 1
1370
+ break
1371
+ else:
1372
+ self.n_iter_ = self.max_iter
1373
+
1374
+ # Build full coefficients
1375
+ if self.fit_intercept:
1376
+ intercept_torch = y_mean - X_mean @ coef
1377
+ coef_full = torch.cat([intercept_torch.reshape(1), coef])
1378
+ else:
1379
+ coef_full = coef
1380
+
1381
+ # Transfer coefficients to CPU
1382
+ coef_full_np = coef_full.cpu().numpy()
1383
+
1384
+ if self.fit_intercept:
1385
+ self.intercept_ = float(coef_full_np[0])
1386
+ self.coef_ = coef_full_np[1:]
1387
+ self._params = coef_full_np
1388
+ else:
1389
+ self.intercept_ = 0.0
1390
+ self.coef_ = coef_full_np
1391
+ self._params = coef_full_np
1392
+
1393
+ df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))
1394
+ self._df_resid = df_resid
1395
+
1396
+ # Inference/diagnostics
1397
+ if self.compute_inference:
1398
+ if self.fit_intercept:
1399
+ X_design = torch.cat([torch.ones((n_samples, 1), dtype=X.dtype, device=X.device), X], dim=1)
1400
+ else:
1401
+ X_design = X
1402
+
1403
+ y_pred = X_design @ coef_full
1404
+ resid = y - y_pred
1405
+
1406
+ if df_resid > 0:
1407
+ scale = torch.sum(resid ** 2) / df_resid
1408
+ self._scale = float(scale.cpu().numpy()) if not torch.isnan(scale) else np.nan
1409
+ else:
1410
+ self._scale = np.nan
1411
+ scale = torch.tensor(np.nan, dtype=X.dtype, device=X.device)
1412
+
1413
+ if self.inference_method == "gpu_ols_inference":
1414
+ # Compute inference fully on GPU
1415
+ XtX_inf = X_design.T @ X_design
1416
+ try:
1417
+ XtX_inv = torch.linalg.inv(XtX_inf)
1418
+ except Exception:
1419
+ XtX_inv = torch.linalg.pinv(XtX_inf)
1420
+
1421
+ bse_gpu = torch.sqrt(scale * torch.diag(XtX_inv))
1422
+ params_gpu = coef_full
1423
+ tvalues_gpu = params_gpu / (bse_gpu + 1e-30)
1424
+
1425
+ from statgpu.inference._distributions_backend import get_distribution
1426
+ t_dist = get_distribution("t", backend="torch", device=str(X.device))
1427
+ pvalues_gpu = torch.minimum(torch.tensor(1.0, device=X.device), 2.0 * t_dist.sf(torch.abs(tvalues_gpu), df=df_resid))
1428
+
1429
+ alpha = 0.05
1430
+ t_crit_gpu = t_dist.ppf(1.0 - alpha / 2.0, df=df_resid)
1431
+ margin_gpu = t_crit_gpu * bse_gpu
1432
+ conf_int_gpu = torch.stack([params_gpu - margin_gpu, params_gpu + margin_gpu], dim=1)
1433
+
1434
+ # Transfer to CPU
1435
+ self._bse = bse_gpu.cpu().numpy()
1436
+ self._tvalues = tvalues_gpu.cpu().numpy()
1437
+ self._pvalues = pvalues_gpu.cpu().numpy()
1438
+ self._conf_int = conf_int_gpu.cpu().numpy()
1439
+
1440
+ # R^2
1441
+ y_mean_gpu = torch.mean(y)
1442
+ ss_tot = torch.sum((y - y_mean_gpu) ** 2)
1443
+ ss_res = torch.sum(resid ** 2)
1444
+ self._rsquared_gpu = float((1 - ss_res / ss_tot).cpu().numpy()) if ss_tot > 0 else 0.0
1445
+
1446
+ self._resid = None
1447
+ self._X_design = None
1448
+ elif self.inference_method == "debiased":
1449
+ # Debiased Lasso inference on Torch GPU
1450
+ self._compute_inference_debiased_torch(X, y, coef)
1451
+
1452
+ # R^2 computation
1453
+ y_mean_gpu = torch.mean(y)
1454
+ ss_tot = torch.sum((y - y_mean_gpu) ** 2)
1455
+ ss_res = torch.sum(resid ** 2)
1456
+ self._rsquared_gpu = float((1 - ss_res / ss_tot).cpu().numpy()) if ss_tot > 0 else 0.0
1457
+
1458
+ self._resid = None
1459
+ self._X_design = None
1460
+ else:
1461
+ raise NotImplementedError(
1462
+ f"Lasso inference_method='{self.inference_method}' is not implemented "
1463
+ "for Torch without CPU fallback."
1464
+ )
1465
+ else:
1466
+ self._scale = np.nan
1467
+ self._resid = None
1468
+ self._X_design = None
1469
+ self._rsquared_gpu = None
1470
+
1471
+ # Cleanup
1472
+ try:
1473
+ del X_design
1474
+ except Exception:
1475
+ pass
1476
+ try:
1477
+ del resid
1478
+ except Exception:
1479
+ pass
1480
+ try:
1481
+ del XtX
1482
+ except Exception:
1483
+ pass
1484
+ try:
1485
+ del Xty
1486
+ except Exception:
1487
+ pass
1488
+ try:
1489
+ del X_centered
1490
+ except Exception:
1491
+ pass
1492
+ try:
1493
+ del y_centered
1494
+ except Exception:
1495
+ pass
1496
+ try:
1497
+ del y_pred
1498
+ except Exception:
1499
+ pass
1500
+ try:
1501
+ del coef_full
1502
+ except Exception:
1503
+ pass
1504
+ self._cleanup_torch_memory()
1505
+
1506
def _fit_gpu_admm(self, X, y, sample_weight=None):
    """Fit the Lasso on a CuPy device with scaled-form ADMM.

    Objective (matches sklearn):
        (1/(2n)) * ||y - Xw||^2 + alpha * ||w||_1

    The constraint ``w = z`` is split so that ``w`` is obtained from a
    ridge-like linear system (factorised once by Cholesky), ``z`` from
    soft-thresholding, and ``u`` is the scaled dual variable.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Feature matrix; converted with ``cp.asarray``.
    y : array-like with n_samples entries
        Response; flattened to 1-D.
    sample_weight : array-like, shape (n_samples,), optional
        When given, rows of ``X`` and entries of ``y`` are multiplied by
        ``sqrt(sample_weight)`` before centering.

    Raises
    ------
    ValueError
        If ``self.admm_rho`` is not strictly positive.
    NotImplementedError
        If ``compute_inference`` is enabled with an ``inference_method``
        other than ``"gpu_ols_inference"`` or ``"debiased"``.

    Notes
    -----
    Sets ``_nobs``, ``n_iter_``, ``intercept_``, ``coef_``, ``_params``,
    ``_df_resid``, ``_scale``, ``_resid``, ``_X_design`` and
    ``_rsquared_gpu``; the ``"gpu_ols_inference"`` branch additionally sets
    ``_bse``, ``_tvalues``, ``_pvalues`` and ``_conf_int``.
    """
    import cupy as cp
    import cupyx.scipy.linalg as cpx_linalg

    n_samples, n_features = X.shape
    self._nobs = n_samples

    X = cp.asarray(X)
    # Flatten *before* weighting: a column-shaped y multiplied by a 1-D
    # weight vector would otherwise broadcast to an (n, n) matrix.
    y = cp.asarray(y).reshape(-1)

    if sample_weight is not None:
        sqrt_sw = cp.sqrt(cp.asarray(sample_weight))
        X = X * sqrt_sw[:, cp.newaxis]
        y = y * sqrt_sw

    # Center so the intercept can be recovered after the solve.
    if self.fit_intercept:
        X_mean = cp.mean(X, axis=0)
        y_mean = cp.mean(y)
        X_centered = X - X_mean
        y_centered = y - y_mean
    else:
        X_centered = X
        y_centered = y

    rho = float(self.admm_rho)
    if rho <= 0:
        raise ValueError("admm_rho must be > 0")

    # ADMM state for the split w = z.
    coef = cp.zeros(n_features, dtype=X.dtype)  # w
    z = cp.zeros(n_features, dtype=X.dtype)
    u = cp.zeros(n_features, dtype=X.dtype)  # scaled dual

    XtX = X_centered.T @ X_centered
    Xty = X_centered.T @ y_centered

    # w-update solves (XtX + rho*n*I) w = Xty + rho*n*(z - u).
    penalty = rho * n_samples
    lhs = XtX + penalty * cp.eye(n_features, dtype=X.dtype)
    Lmat = cp.linalg.cholesky(lhs)  # factorise once, reuse every iteration

    def solve_w(rhs):
        # Solve Lmat @ (Lmat.T @ w) = rhs with two triangular solves.
        tmp = cpx_linalg.solve_triangular(Lmat, rhs, lower=True)
        return cpx_linalg.solve_triangular(Lmat.T, tmp, lower=False)

    thresh = self.alpha / rho

    for iteration in range(self.max_iter):
        coef_old = coef
        coef = solve_w(Xty + penalty * (z - u))
        z = self._soft_threshold_cupy(coef + u, thresh)  # prox of the l1 term
        u = u + (coef - z)

        if self.stopping == "kkt":
            grad_sse = (XtX @ coef - Xty) / n_samples
            violation = cp.max(cp.maximum(cp.abs(grad_sse) - self.alpha, 0.0))
            converged = violation < self.tol
        else:
            # Legacy stopping rule: l1 change of the coefficient vector.
            converged = cp.sum(cp.abs(coef - coef_old)) < self.tol
        if converged:
            self.n_iter_ = iteration + 1
            break
    else:
        self.n_iter_ = self.max_iter

    # Assemble [intercept, coef] and the matching design matrix.
    if self.fit_intercept:
        intercept_gpu = y_mean - X_mean @ coef
        coef_full = cp.concatenate([intercept_gpu.reshape(1), coef])
        X_design = cp.concatenate(
            [cp.ones((n_samples, 1), dtype=X.dtype), X], axis=1
        )
    else:
        coef_full = coef
        X_design = X

    coef_full_np = coef_full.get()
    if self.fit_intercept:
        self.intercept_ = float(coef_full_np[0])
        self.coef_ = coef_full_np[1:]
    else:
        self.intercept_ = 0.0
        self.coef_ = coef_full_np
    self._params = coef_full_np

    df_resid = n_samples - (n_features + (1 if self.fit_intercept else 0))
    self._df_resid = df_resid

    # Pre-bind so the cleanup at the end works on every path.
    resid = None
    y_pred = None

    if self.compute_inference:
        y_pred = X_design @ coef_full
        resid = y - y_pred
        if df_resid > 0:
            scale = cp.sum(resid ** 2) / df_resid
            self._scale = float(scale.get()) if not cp.isnan(scale) else np.nan
        else:
            self._scale = np.nan
            scale = cp.nan

        if self.inference_method == "gpu_ols_inference":
            # Keep the inference path on GPU and transfer only small vectors.
            XtX_inf = X_design.T @ X_design
            try:
                XtX_inv = cp.linalg.inv(XtX_inf)
            except Exception:
                XtX_inv = cp.linalg.pinv(XtX_inf)

            bse_gpu = cp.sqrt(scale * cp.diag(XtX_inv))
            params_gpu = coef_full
            tvalues_gpu = params_gpu / (bse_gpu + 1e-30)
            # NOTE(review): `t` is resolved from module scope; presumably a
            # Student-t distribution that accepts CuPy arrays - confirm.
            pvalues_gpu = cp.minimum(
                1.0, 2.0 * t.sf(cp.abs(tvalues_gpu), df=df_resid)
            )

            alpha = 0.05
            t_crit_gpu = t.ppf(1.0 - alpha / 2.0, df=df_resid)
            margin_gpu = t_crit_gpu * bse_gpu
            conf_int_gpu = cp.stack(
                [params_gpu - margin_gpu, params_gpu + margin_gpu], axis=1
            )

            self._bse = cp.asnumpy(bse_gpu)
            self._tvalues = cp.asnumpy(tvalues_gpu)
            self._pvalues = cp.asnumpy(pvalues_gpu)
            self._conf_int = cp.asnumpy(conf_int_gpu)
        elif self.inference_method == "debiased":
            self._compute_inference_debiased_gpu(X, y, coef)
        else:
            raise NotImplementedError(
                f"Lasso inference_method='{self.inference_method}' is not implemented "
                "for CuPy without CPU fallback."
            )

        # R^2 is computed on device; residuals/design are not kept on host.
        y_mean_gpu = cp.mean(y)
        ss_tot = cp.sum((y - y_mean_gpu) ** 2)
        ss_res = cp.sum(resid ** 2)
        self._rsquared_gpu = (
            float(cp.asnumpy(1 - ss_res / ss_tot)) if ss_tot > 0 else 0.0
        )
        self._resid = None
        self._X_design = None
    else:
        self._scale = np.nan
        self._resid = None
        self._X_design = None
        self._rsquared_gpu = None

    # Drop references to large temporaries before the optional pool cleanup.
    # Rebinding to None never fails, unlike `del` on a name that was not set.
    X_design = resid = y_pred = coef_full = None
    XtX = Xty = X_centered = y_centered = lhs = Lmat = None
    self._cleanup_cuda_memory()
1723
+
1724
def _compute_inference(self):
    """Compute standard errors, t-statistics, p-values and 95% intervals.

    Dispatches to the bootstrap or debiased routines when
    ``inference_method`` selects them, and does nothing for
    ``"gpu_ols_inference"`` because that inference is produced on the GPU
    during fitting. Otherwise the classical OLS formulas are applied to the
    stored design matrix.

    Returns
    -------
    None, or whatever the delegated bootstrap/debiased routine returns.
    Results are written to ``_bse``, ``_tvalues``, ``_pvalues`` and
    ``_conf_int``.
    """
    if self.inference_method == "bootstrap":
        return self._compute_inference_bootstrap()
    if self.inference_method == "debiased":
        return self._compute_inference_debiased()
    if self.inference_method == "gpu_ols_inference":
        # Inference already computed on GPU during the fit.
        return
    if self._X_design is None or self._scale is None or np.isnan(self._scale):
        return

    X = self._X_design

    # Build the Gram matrix once; fall back to the pseudo-inverse if singular.
    gram = X.T @ X
    try:
        gram_inv = np.linalg.inv(gram)
    except np.linalg.LinAlgError:
        gram_inv = np.linalg.pinv(gram)

    self._bse = np.sqrt(self._scale * np.diag(gram_inv))
    self._tvalues = self._params / self._bse
    # Use the survival function: 1 - cdf rounds to exactly 0 for large |t|.
    self._pvalues = 2 * stats.t.sf(np.abs(self._tvalues), self._df_resid)

    alpha = 0.05
    t_crit = stats.t.ppf(1 - alpha / 2, self._df_resid)
    margin = t_crit * self._bse
    self._conf_int = np.column_stack([
        self._params - margin,
        self._params + margin,
    ])
1753
+
1754
def _compute_inference_bootstrap(self) -> None:
    """
    Bootstrap inference for Lasso via residual resampling.

    For each of ``n_bootstrap`` replicates the stored residuals are
    resampled with replacement, added to the fitted values, and a fresh CPU
    ``Lasso`` is fitted to the synthetic response. The replicate parameter
    vectors give standard errors (sample std, ``ddof=1``), sign-change
    two-sided p-values and 95% percentile intervals.

    Does nothing when the design, residuals or response are not stored, or
    when ``n_bootstrap <= 0``.

    Notes
    -----
    This is more robust than the naive OLS-based inference, but it is still
    not full "post-selection inference" for Lasso.
    """
    if self._X_design is None or self._resid is None or self._y is None:
        return

    if self.n_bootstrap <= 0:
        return

    rng = np.random.default_rng(self.bootstrap_random_state)
    X = self._X_design
    resid = self._resid
    y_pred = self._y - resid

    params_dim = self._params.shape[0]
    boot_params = np.zeros((self.n_bootstrap, params_dim), dtype=float)

    # The refit expects raw X (without the intercept column); the slice is
    # the same for every replicate, so take it once.
    X_refit = X[:, 1:] if self.fit_intercept else X

    # Precompute the Lipschitz constant if needed for CPU FISTA:
    # L = lambda_max(Xc^T Xc) / n for the centered design.
    lipschitz_L = self.lipschitz_L
    if self.cpu_solver == "fista" and lipschitz_L is None:
        X_centered = X_refit - X_refit.mean(axis=0, keepdims=True)
        eigvals = np.linalg.eigvalsh(X_centered.T @ X_centered)
        lipschitz_L = float(eigvals[-1] / X_refit.shape[0])

    refit_kwargs = dict(
        alpha=self.alpha,
        fit_intercept=self.fit_intercept,
        max_iter=self.max_iter,
        tol=self.tol,
        stopping=self.stopping,
        inference_method="cpu_ols_inference",
        n_bootstrap=0,
        bootstrap_random_state=None,
        device="cpu",
        compute_inference=False,
        solver=self.solver,
        cpu_solver=self.cpu_solver,
        lipschitz_L=lipschitz_L,
        admm_rho=self.admm_rho,
    )

    for b in range(self.n_bootstrap):
        eps_star = rng.choice(resid, size=resid.shape[0], replace=True)
        refit = Lasso(**refit_kwargs)
        refit.fit(X_refit, y_pred + eps_star)
        boot_params[b, :] = refit._params

    # Standard errors from the spread of the replicate estimates.
    self._bse = np.std(boot_params, axis=0, ddof=1)
    self._params = np.asarray(self._params, dtype=float)

    # Two-sided p-values using the sign-change probability, per coefficient.
    p_lower = np.mean(boot_params <= 0.0, axis=0)
    p_upper = np.mean(boot_params >= 0.0, axis=0)
    self._pvalues = np.minimum(2.0 * np.minimum(p_lower, p_upper), 1.0)

    # Percentile confidence intervals at the 95% level.
    lower_q = 0.05 / 2.0
    upper_q = 1.0 - lower_q
    self._conf_int = np.column_stack([
        np.quantile(boot_params, lower_q, axis=0),
        np.quantile(boot_params, upper_q, axis=0),
    ])

    # t-stats (approx) from bootstrap SE; the epsilon avoids division by 0.
    self._tvalues = self._params / (self._bse + 1e-30)
1842
+
1843
def _compute_inference_debiased(self) -> None:
    """Debiased Lasso inference (Javanmard-Montanari / Zhang-Zhang).

    Constructs the decorrelation matrix M via node-wise Lasso (one CPU
    FISTA fit per feature, cached under a key derived from the design),
    then computes the debiased estimator, standard errors, z-statistics,
    p-values, and per-coefficient (marginal) 95% confidence intervals.

    Does nothing when the design matrix or residuals are not stored.
    Results are written to ``_bse``, ``_tvalues``, ``_pvalues``,
    ``_conf_int`` and ``_params``; ``M`` is kept in ``_debiased_M_cpu``.
    When an intercept is fitted its row is prepended, with a standard error
    taken from the (0, 0) entry of the inverse Gram matrix of the full
    design.
    """
    if self._X_design is None or self._resid is None:
        return

    # Feature-only design (drop the leading intercept column when present).
    X = self._X_design[:, 1:] if self.fit_intercept else self._X_design

    n, p = X.shape
    coef = self.coef_.copy()

    Sigma_hat = X.T @ X / n
    resid_lasso = self._resid

    # --- noise variance: sigma^2 = RSS / (n - s_hat) ---
    s_hat = int(np.sum(np.abs(coef) > 0))
    denom = max(n - s_hat, 1)
    sigma2 = np.sum(resid_lasso ** 2) / denom

    # --- node-wise Lasso to build M (p x p), with cross-fit cache ---
    lam_nw = np.sqrt(2.0 * np.log(max(p, 2)) / n)
    m_cache_key = _debiased_m_key_from_numpy_design(
        X,
        n=n,
        p=p,
        lam_nw=lam_nw,
        tol=float(self.tol),
    )
    M_cached = _debiased_m_cache_get(m_cache_key)
    if M_cached is not None:
        M = np.asarray(M_cached, dtype=X.dtype)
    else:
        M = np.zeros((p, p), dtype=X.dtype)
        for j in range(p):
            cols = np.concatenate([np.arange(0, j), np.arange(j + 1, p)])
            X_minus_j = X[:, cols]
            x_j = X[:, j]

            nw = Lasso(
                alpha=lam_nw,
                fit_intercept=False,
                max_iter=500,
                tol=1e-5,
                device="cpu",
                cpu_solver="fista",
                compute_inference=False,
            )
            nw.fit(X_minus_j, x_j)
            gamma_j = nw.coef_

            z_j = x_j - X_minus_j @ gamma_j
            C_j = z_j @ x_j / n

            if abs(C_j) < 1e-30:
                # Degenerate node regression: leave an identity row.
                M[j, j] = 1.0
                continue

            M[j, j] = 1.0 / C_j
            M[j, cols] = -gamma_j / C_j
        _debiased_m_cache_put(m_cache_key, np.asarray(M, dtype=np.float64))

    # --- debiased estimates ---
    # Multiply X.T by the residual first so no (p x n) intermediate is built.
    theta_db = coef + (M @ (X.T @ resid_lasso)) / n
    self._debiased_M_cpu = M

    # --- covariance and standard errors ---
    V = M @ Sigma_hat @ M.T
    se = np.sqrt(sigma2 * np.diag(V) / n)

    z_stats = theta_db / (se + 1e-30)
    pvalues = 2.0 * (1.0 - _norm_dist.cdf(np.abs(z_stats)))

    alpha_ci = 0.05
    z_crit = _norm_dist.ppf(1.0 - alpha_ci / 2.0)
    ci = np.column_stack([theta_db - z_crit * se, theta_db + z_crit * se])

    if not self.fit_intercept:
        self._bse = se
        self._tvalues = z_stats
        self._pvalues = pvalues
        self._conf_int = ci
        self._params = theta_db
        return

    # Intercept SE via OLS formula: sigma * sqrt([(X'X)^-1]_00) on the full
    # design; the Gram matrix is built once and reused by the fallback.
    X_full = self._X_design
    gram_full = X_full.T @ X_full
    try:
        gram_full_inv = np.linalg.inv(gram_full)
    except np.linalg.LinAlgError:
        gram_full_inv = np.linalg.pinv(gram_full)
    se_intercept = np.sqrt(sigma2 * gram_full_inv[0, 0])
    z_intercept = self.intercept_ / (se_intercept + 1e-30)
    p_intercept = 2.0 * (1.0 - _norm_dist.cdf(np.abs(z_intercept)))
    ci_intercept = np.array([
        self.intercept_ - z_crit * se_intercept,
        self.intercept_ + z_crit * se_intercept,
    ])

    self._bse = np.concatenate([[se_intercept], se])
    self._tvalues = np.concatenate([[z_intercept], z_stats])
    self._pvalues = np.concatenate([[p_intercept], pvalues])
    self._conf_int = np.vstack([ci_intercept[np.newaxis, :], ci])
    self._params = np.concatenate([[self.intercept_], theta_db])
1953
+
1954
def _compute_inference_debiased_gpu(self, X_gpu, y_gpu, coef_gpu):
    """GPU path for debiased Lasso inference.

    Builds the decorrelation matrix M with batched node-wise Lasso solves
    on the device (cached under a hash of the design), forms the debiased
    estimate, standard errors, z-statistics, p-values and 95% marginal
    intervals, and optionally simultaneous intervals from a max-|Z|
    multiplier bootstrap. Only the final small arrays are copied to host.

    Parameters
    ----------
    X_gpu : cupy.ndarray, shape (n, p)
        Raw feature matrix on GPU (no intercept column).
    y_gpu : cupy.ndarray, shape (n,)
        Response on GPU.
    coef_gpu : cupy.ndarray, shape (p,)
        Lasso coefficients on GPU (no intercept).

    Raises
    ------
    RuntimeError
        If simultaneous inference is enabled and the target set contains
        no (feature) coefficients.

    Notes
    -----
    Writes ``_bse``, ``_tvalues``, ``_pvalues``, ``_conf_int`` and
    ``_params``; with simultaneous inference enabled also the
    ``_conf_int_simultaneous`` / ``_simultaneous_*`` attributes.
    ``norm`` is resolved from module scope - presumably a normal
    distribution that accepts CuPy arrays; confirm against the imports.
    """
    import cupy as cp

    n, p = X_gpu.shape
    # One Gram product serves both Sigma_hat and the node-wise solves.
    gram = X_gpu.T @ X_gpu
    Sigma_hat = gram / n

    resid_lasso = y_gpu - X_gpu @ coef_gpu
    if self.fit_intercept:
        resid_lasso = resid_lasso - cp.mean(y_gpu) + cp.mean(X_gpu, axis=0) @ coef_gpu

    # Noise variance: RSS / max(1, n - number of non-zero coefficients).
    s_hat_gpu = cp.sum(cp.abs(coef_gpu) > 0).astype(cp.float64)
    denom_gpu = cp.maximum(1.0, float(n) - s_hat_gpu)
    sigma2_gpu = cp.asarray(cp.sum(resid_lasso ** 2) / denom_gpu, dtype=cp.float64)

    lam_nw = float(np.sqrt(2.0 * np.log(max(p, 2)) / n))
    alpha_nw = np.asarray([lam_nw], dtype=np.float64)
    tiny = X_gpu.dtype.type(1e-30)
    zero = X_gpu.dtype.type(0.0)
    one = X_gpu.dtype.type(1.0)

    # Cache key: shape, dtype, penalty, tolerance and the design's bytes,
    # streamed to host in row chunks to bound the transfer size.
    x_hasher = hashlib.blake2b(digest_size=32)
    x_hasher.update(np.asarray([int(n), int(p)], dtype=np.int64).tobytes())
    x_hasher.update(str(X_gpu.dtype).encode("utf-8"))
    x_hasher.update(
        np.asarray([float(lam_nw), float(self.tol)], dtype=np.float64).tobytes()
    )
    row_chunk = max(1, min(int(n), _LASSO_DEBIASED_M_GPU_HASH_ROW_CHUNK))
    for start in range(0, int(n), row_chunk):
        stop = min(int(n), start + row_chunk)
        x_hasher.update(cp.asnumpy(X_gpu[start:stop]).tobytes())
    m_cache_key = x_hasher.hexdigest()

    M_cached = _debiased_m_cache_get(m_cache_key)
    if M_cached is not None:
        M = cp.asarray(M_cached, dtype=X_gpu.dtype)
    else:
        # Keep node-wise Lasso solves on GPU to avoid per-feature host
        # round-trips; batch several target features per solver call.
        M = cp.zeros((p, p), dtype=X_gpu.dtype)
        Sigma_diag = cp.diag(Sigma_hat)

        try:
            free_mem, _ = cp.cuda.Device().mem_info
            bytes_per_fold = int(max(8, (p - 1) * (p - 1) * 8 * 2))
            chunk_size = int(max(4, min(64, free_mem // max(bytes_per_fold, 1))))
        except Exception:
            chunk_size = 16
        chunk_size = max(4, min(int(p), chunk_size))

        for j0 in range(0, p, chunk_size):
            j1 = min(p, j0 + chunk_size)
            bsz = j1 - j0
            j_batch = cp.arange(j0, j1, dtype=cp.int32)

            # Per-j "all except j" column index matrix of shape (bsz, p-1).
            base = cp.arange(p - 1, dtype=cp.int32).reshape(1, -1)
            cols_batch = base + (base >= j_batch.reshape(-1, 1))

            # Gather batched Gram/Xty blocks from the full Gram matrix.
            XtX_batch = gram[
                cols_batch[:, :, cp.newaxis],
                cols_batch[:, cp.newaxis, :],
            ]
            Xty_batch = gram[cols_batch, j_batch.reshape(-1, 1)].reshape(bsz, p - 1)

            coefs_batch_desc, _ = _solve_lasso_path_gpu_fista_multi_fold_from_gram(
                XtX_batch,
                Xty_batch,
                n_samples_vec=np.full((bsz,), float(n), dtype=np.float64),
                alphas_desc=alpha_nw,
                max_iter=500,
                tol=1e-5,
                stopping="coef_delta",
                lipschitz_L=None,
                check_every=8,
            )
            gamma_batch = cp.asarray(coefs_batch_desc[:, 0, :], dtype=X_gpu.dtype)

            # C_j = Sigma_jj - Sigma_{j,-j} gamma_j
            sigma_j_cols = Sigma_hat[j_batch[:, cp.newaxis], cols_batch]
            C_batch = Sigma_diag[j_batch] - cp.sum(sigma_j_cols * gamma_batch, axis=1)

            # Degenerate rows (|C_j| ~ 0) become identity rows.
            small_c = cp.abs(C_batch) < tiny
            inv_c = cp.where(small_c, zero, one / C_batch)
            M[j_batch, j_batch] = cp.where(small_c, one, inv_c)
            M[j_batch[:, cp.newaxis], cols_batch] = -gamma_batch * inv_c.reshape(-1, 1)

            del XtX_batch
            del Xty_batch
            del coefs_batch_desc
            del gamma_batch
            del sigma_j_cols
        _debiased_m_cache_put(m_cache_key, cp.asnumpy(M))
    del gram

    # Recompute the full residual from the original fit.
    if self.fit_intercept:
        y_pred = X_gpu @ coef_gpu + cp.asarray(self.intercept_, dtype=X_gpu.dtype)
    else:
        y_pred = X_gpu @ coef_gpu
    resid_full = y_gpu - y_pred

    # Multiply X.T by the residual first so no (p x n) intermediate is built.
    theta_db = coef_gpu + (M @ (X_gpu.T @ resid_full)) / n

    V = M @ Sigma_hat @ M.T
    se = cp.sqrt(sigma2_gpu * cp.diag(V) / n)

    z_stats = theta_db / (se + 1e-30)
    pvalues = cp.minimum(1.0, 2.0 * norm.sf(cp.abs(z_stats)))

    alpha_ci = 0.05
    z_crit = norm.ppf(1.0 - alpha_ci / 2.0)
    ci = cp.stack([theta_db - z_crit * se, theta_db + z_crit * se], axis=1)

    if self.fit_intercept:
        # Intercept SE from the (0, 0) entry of the inverse Gram matrix of
        # the design with a leading column of ones.
        X_full = cp.concatenate(
            [cp.ones((n, 1), dtype=X_gpu.dtype), X_gpu], axis=1
        )
        gram_design = X_full.T @ X_full
        try:
            gram_design_inv = cp.linalg.inv(gram_design)
        except Exception:
            gram_design_inv = cp.linalg.pinv(gram_design)
        se_intercept = cp.sqrt(sigma2_gpu * gram_design_inv[0, 0])
        intercept_gpu = cp.asarray(self.intercept_, dtype=cp.float64)
        z_intercept = intercept_gpu / (se_intercept + 1e-30)
        p_intercept = cp.minimum(1.0, 2.0 * norm.sf(cp.abs(z_intercept).reshape(1)))
        ci_intercept = cp.stack([
            intercept_gpu - z_crit * se_intercept,
            intercept_gpu + z_crit * se_intercept,
        ]).reshape(1, 2)

        bse_gpu = cp.concatenate([se_intercept.reshape(1), se])
        tvalues_gpu = cp.concatenate([z_intercept.reshape(1), z_stats])
        pvalues_gpu = cp.concatenate([p_intercept.reshape(1), pvalues])
        conf_int_gpu = cp.concatenate([ci_intercept, ci], axis=0)
        params_gpu = cp.concatenate([intercept_gpu.reshape(1), theta_db])
    else:
        bse_gpu = se
        tvalues_gpu = z_stats
        pvalues_gpu = pvalues
        conf_int_gpu = ci
        params_gpu = theta_db

    if self.enable_simultaneous_inference:
        # GPU-native simultaneous CI via max-|Z| multiplier bootstrap.
        param_target_idx_np = self._get_simultaneous_target_indices(
            int(params_gpu.shape[0])
        )
        param_target_idx_gpu = cp.asarray(param_target_idx_np, dtype=cp.int32)
        if param_target_idx_gpu.size == 0:
            raise RuntimeError(
                "No coefficients selected for simultaneous inference target set."
            )

        # Map parameter indices to feature indices (drop the intercept).
        feature_offset = 1 if self.fit_intercept else 0
        feature_target_gpu = param_target_idx_gpu - feature_offset
        feature_target_gpu = feature_target_gpu[feature_target_gpu >= 0]
        if feature_target_gpu.size == 0:
            raise RuntimeError(
                "No feature coefficients selected for simultaneous inference target set."
            )

        B = int(self.simultaneous_n_bootstrap)
        rng = cp.random.RandomState(self.simultaneous_random_state)
        se_target_gpu = cp.take(se, feature_target_gpu)
        M_target = cp.take(M, feature_target_gpu, axis=0)

        def max_abs_z(n_draws):
            # One block of multiplier draws -> per-draw max studentised score.
            xi = rng.standard_normal(size=(n_draws, n)).astype(cp.float64, copy=False)
            weighted = xi * resid_full.reshape(1, -1)
            score_target = (weighted @ X_gpu) @ M_target.T / float(max(n, 1))
            z_star_target = score_target / (se_target_gpu.reshape(1, -1) + 1e-30)
            return cp.max(cp.abs(z_star_target), axis=1)

        # Run bootstrap in one shot when memory allows to reduce
        # kernel-launch overhead; otherwise fall back to chunks. The fallback
        # runs after the except block so buffers from the failed attempt are
        # no longer referenced by the active exception.
        try:
            max_stats_gpu = max_abs_z(B)
        except Exception:
            max_stats_gpu = None
        if max_stats_gpu is None:
            free_mem, _ = cp.cuda.Device().mem_info
            bytes_per_row = max(8 * (3 * n + 2 * p + 64), 8)
            est_chunk = int(max(64, min(4096, free_mem // bytes_per_row)))
            chunk = min(B, max(64, est_chunk))
            max_stats_gpu = cp.empty((B,), dtype=cp.float64)
            filled = 0
            while filled < B:
                bsz = min(chunk, B - filled)
                max_stats_gpu[filled : filled + bsz] = max_abs_z(bsz)
                filled += bsz

        critical_gpu = cp.quantile(
            max_stats_gpu, 1.0 - float(self.simultaneous_alpha)
        )
        conf_sim_gpu = cp.array(conf_int_gpu, copy=True)
        target_params = cp.take(params_gpu, param_target_idx_gpu)
        target_margin = critical_gpu * cp.take(bse_gpu, param_target_idx_gpu)
        conf_sim_gpu[param_target_idx_gpu, 0] = target_params - target_margin
        conf_sim_gpu[param_target_idx_gpu, 1] = target_params + target_margin

        target_mask = np.zeros(int(params_gpu.shape[0]), dtype=bool)
        target_mask[param_target_idx_np] = True
        self._conf_int_simultaneous = cp.asnumpy(conf_sim_gpu)
        self._simultaneous_enabled = True
        self._simultaneous_method = self.simultaneous_method
        self._simultaneous_alpha = float(self.simultaneous_alpha)
        self._simultaneous_n_bootstrap = B
        self._simultaneous_critical_value = float(cp.asnumpy(critical_gpu))
        self._simultaneous_target_mask = target_mask

    self._bse = cp.asnumpy(bse_gpu)
    self._tvalues = cp.asnumpy(tvalues_gpu)
    self._pvalues = cp.asnumpy(pvalues_gpu)
    self._conf_int = cp.asnumpy(conf_int_gpu)
    self._params = cp.asnumpy(params_gpu)
2188
+
2189
def _get_simultaneous_target_indices(self, n_params: int):
    """Return the parameter indices covered by simultaneous inference.

    All ``n_params`` positions are returned, except that position 0 is left
    out when an intercept is fitted and ``simultaneous_include_intercept``
    is false.
    """
    skip_intercept = self.fit_intercept and (not self.simultaneous_include_intercept)
    first = 1 if skip_intercept else 0
    return np.arange(first, n_params, dtype=int)
2193
+
2194
def _compute_simultaneous_inference(self):
    """Run the max-|Z| bootstrap for simultaneous intervals when applicable.

    Returns without doing anything when simultaneous inference is disabled,
    has already been produced, the inference method is not ``"debiased"``,
    or the marginal inference results are missing.

    Raises
    ------
    RuntimeError
        If the stored design matrix or residuals are unavailable.
    """
    if not self.enable_simultaneous_inference:
        return

    already_done = self._simultaneous_enabled and self._conf_int_simultaneous is not None
    if already_done or self.inference_method != "debiased":
        return

    marginal_missing = (
        self._params is None or self._bse is None or self._conf_int is None
    )
    if marginal_missing:
        return

    state_available = not (self._X_design is None or self._resid is None)
    if state_available:
        self._compute_simultaneous_ci_maxz_bootstrap()
        return

    raise RuntimeError(
        "Simultaneous debiased inference requires accessible design/residual "
        "state; re-fit with compute_inference=True."
    )
2209
+
2210
def compute_debiased_inference(self):
    """Explicitly recompute debiased inference for a fitted model.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If ``inference_method`` is not ``"debiased"``. Whatever
        ``_check_is_fitted`` raises for an unfitted model propagates.
    """
    self._check_is_fitted()
    if self.inference_method == "debiased":
        self._compute_inference()
        return self
    raise ValueError("compute_debiased_inference requires inference_method='debiased'.")
2217
+
2218
def compute_debiased_inference_(self):
    """Deprecated alias for :meth:`compute_debiased_inference`.

    Emits a ``DeprecationWarning`` attributed to the caller and forwards to
    the non-deprecated method, returning its result.
    """
    notice = (
        "compute_debiased_inference_ is deprecated and will be removed in a future "
        "release; use compute_debiased_inference instead."
    )
    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    return self.compute_debiased_inference()
2227
+
2228
def compute_simultaneous_inference(self):
    """Explicitly (re)compute simultaneous inference for a fitted model.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If ``enable_simultaneous_inference`` is false. Whatever
        ``_check_is_fitted`` raises for an unfitted model propagates.
    """
    self._check_is_fitted()
    if self.enable_simultaneous_inference:
        self._compute_simultaneous_inference()
        return self
    raise ValueError(
        "compute_simultaneous_inference requires enable_simultaneous_inference=True."
    )
2237
+
2238
def compute_simultaneous_inference_(self):
    """Deprecated alias for :meth:`compute_simultaneous_inference`.

    Emits a ``DeprecationWarning`` attributed to the caller and forwards to
    the non-deprecated method, returning its result.
    """
    notice = (
        "compute_simultaneous_inference_ is deprecated and will be removed in a "
        "future release; use compute_simultaneous_inference instead."
    )
    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    return self.compute_simultaneous_inference()
2247
+
2248
def _compute_simultaneous_ci_maxz_bootstrap(self):
    """Compute simultaneous CIs using max-|Z| multiplier bootstrap.

    Gaussian multipliers reweight the stored residuals; each draw is pushed
    through ``X`` and the decorrelation matrix ``M`` and studentised by the
    feature standard errors. The ``1 - simultaneous_alpha`` quantile of the
    per-draw maximum absolute statistic replaces the marginal critical value
    for the targeted parameters. ``M`` is reused from
    ``_debiased_M_cpu`` when its shape matches, otherwise rebuilt by
    node-wise Lasso and stored back.

    Raises
    ------
    RuntimeError
        If the design has no feature columns or the target set contains no
        feature coefficients.
    """
    intercept_cols = 1 if self.fit_intercept else 0

    # Feature-only design used by the debiased estimator.
    X = np.asarray(self._X_design[:, intercept_cols:], dtype=float)
    resid = np.asarray(self._resid, dtype=float).reshape(-1)
    n, p = X.shape
    if p == 0:
        raise RuntimeError("Simultaneous inference requires at least one feature.")
    safe_n = max(n, 1)

    # Reuse M from debiased inference when available to avoid duplicate
    # node-wise solves.
    M = self._debiased_M_cpu
    if M is None or M.shape != (p, p):
        lam_nw = np.sqrt(2.0 * np.log(max(p, 2)) / safe_n)
        M = np.zeros((p, p), dtype=float)
        for j in range(p):
            others = np.concatenate([np.arange(0, j), np.arange(j + 1, p)])
            X_others = X[:, others]
            target_col = X[:, j]
            node_fit = Lasso(
                alpha=lam_nw,
                fit_intercept=False,
                max_iter=500,
                tol=1e-5,
                device="cpu",
                cpu_solver="fista",
                compute_inference=False,
            )
            node_fit.fit(X_others, target_col)
            gamma_j = node_fit.coef_
            node_resid = target_col - X_others @ gamma_j
            c_j = float(node_resid @ target_col / safe_n)
            if abs(c_j) < 1e-30:
                # Degenerate node regression: leave an identity row.
                M[j, j] = 1.0
            else:
                M[j, j] = 1.0 / c_j
                M[j, others] = -gamma_j / c_j
        self._debiased_M_cpu = M

    # Parameter-level targets and their feature-level counterparts.
    param_target_idx = self._get_simultaneous_target_indices(len(self._params))
    shifted = param_target_idx - intercept_cols
    feature_target_idx = shifted[shifted >= 0]
    if feature_target_idx.size == 0:
        raise RuntimeError(
            "No feature coefficients selected for simultaneous inference target set."
        )

    # Bootstrap the studentized process max_j |Z*_j| in blocks of draws.
    se_feat = np.asarray(self._bse[intercept_cols:], dtype=float)
    studentiser = se_feat.reshape(1, -1) + 1e-30
    resid_row = resid.reshape(1, -1)
    rng = np.random.default_rng(self.simultaneous_random_state)
    B = int(self.simultaneous_n_bootstrap)
    block = min(256, B)
    max_stats = np.empty(B, dtype=float)
    done = 0
    while done < B:
        draws = min(block, B - done)
        multipliers = rng.standard_normal(size=(draws, n))
        score = ((multipliers * resid_row) @ X) @ M.T / float(safe_n)
        z_star = score / studentiser
        max_stats[done:done + draws] = np.max(
            np.abs(z_star[:, feature_target_idx]), axis=1
        )
        done += draws

    critical = float(np.quantile(max_stats, 1.0 - self.simultaneous_alpha))
    params = np.asarray(self._params, dtype=float)
    bse = np.asarray(self._bse, dtype=float)

    # Start from the marginal intervals and overwrite only the target rows.
    conf_sim = np.array(self._conf_int, copy=True, dtype=float)
    conf_sim[param_target_idx, 0] = params[param_target_idx] - critical * bse[param_target_idx]
    conf_sim[param_target_idx, 1] = params[param_target_idx] + critical * bse[param_target_idx]

    targeted = np.zeros(len(params), dtype=bool)
    targeted[param_target_idx] = True

    self._conf_int_simultaneous = conf_sim
    self._simultaneous_enabled = True
    self._simultaneous_method = self.simultaneous_method
    self._simultaneous_alpha = float(self.simultaneous_alpha)
    self._simultaneous_n_bootstrap = B
    self._simultaneous_critical_value = critical
    self._simultaneous_target_mask = targeted
2332
+
2333
@property
def rsquared(self):
    """R-squared.

    Computed from the stored response and residuals when residuals are
    available on host. Without host residuals the value cached in
    ``_rsquared_gpu`` is returned when that attribute exists and is set;
    otherwise ``None``. A response with zero total sum of squares gives 0.0.
    """
    if self._resid is None:
        # GPU fits may avoid transferring residuals and cache R^2 instead.
        return getattr(self, "_rsquared_gpu", None)
    if self._y is None:
        return None

    centered = self._y - np.mean(self._y)
    ss_tot = np.sum(centered ** 2)
    ss_res = np.sum(self._resid ** 2)
    if ss_tot > 0:
        return 1 - ss_res / ss_tot
    return 0.0
2347
+
2348
@property
def rsquared_adj(self):
    """Adjusted R-squared.

    Computed as ``1 - (1 - R^2) * (nobs - 1) / df_resid``. Returns ``None``
    when the number of observations or R-squared is unavailable, or when
    the residual degrees of freedom are missing or not positive (the same
    guard ``fvalue`` applies).
    """
    if self._nobs is None:
        return None
    r2 = self.rsquared
    if r2 is None:
        return None
    df_resid = self._df_resid
    if df_resid is None or df_resid <= 0:
        # Saturated or over-parameterised fit: the adjustment is undefined.
        return None
    return 1 - (1 - r2) * (self._nobs - 1) / df_resid
2358
+
2359
@property
def fvalue(self):
    """F-statistic.

    With the response and residuals on host the statistic is
    ``(SS_reg / k) / (SS_res / df_resid)`` where ``k = len(coef_)``; it is
    ``inf`` when there are no coefficients or the residual sum of squares
    is not positive. Otherwise it is derived from ``rsquared``.

    Returns ``None`` when R-squared is unavailable in the fallback path or
    the residual degrees of freedom are missing or not positive.
    """
    k = len(self.coef_)
    df_resid = self._df_resid
    df_usable = df_resid is not None and df_resid > 0

    if self._y is not None and self._resid is not None:
        y_mean = np.mean(self._y)
        ss_tot = np.sum((self._y - y_mean) ** 2)
        ss_res = np.sum(self._resid ** 2)
        ss_reg = ss_tot - ss_res
        if k == 0 or ss_res <= 0:
            return np.inf
        if not df_usable:
            # Same guard as the fallback path below; avoids dividing by a
            # zero, negative or missing degrees-of-freedom value.
            return None
        return (ss_reg / k) / (ss_res / df_resid)

    # GPU inference mode may skip transferring residual vectors to host.
    r2 = self.rsquared
    if r2 is None:
        return None
    if k <= 0 or not df_usable:
        return None
    if r2 >= 1.0:
        return np.inf
    return (r2 / k) / ((1.0 - r2) / df_resid)
2382
+
2383
@property
def f_pvalue(self):
    """p-value for F-statistic.

    Upper-tail probability of an F distribution with ``len(coef_)`` and
    ``df_resid`` degrees of freedom. Returns ``None`` when there are no
    coefficients, the residual degrees of freedom are missing or not
    positive, the F-statistic is unavailable, or the probability is not
    finite; returns 0.0 for an infinite F-statistic.
    """
    k = len(self.coef_)
    if k <= 0 or self._df_resid is None or self._df_resid <= 0:
        return None
    fv = self.fvalue
    if fv is None:
        return None
    if fv == np.inf:
        # An infinite F-statistic corresponds to a perfect-fit / zero-residual
        # case, so the upper-tail probability tends to 0.
        return 0.0
    # Survival function keeps precision where 1 - cdf would round to 0.
    pval = stats.f.sf(fv, k, self._df_resid)
    if not np.isfinite(pval):
        return None
    return float(np.clip(pval, 0.0, 1.0))
2402
+
2403
@property
def aic(self):
    """Akaike Information Criterion.

    ``-2 * llf + 2 * len(_params)``. Returns ``None`` when the number of
    observations is unknown, the scale is missing or NaN, or the
    log-likelihood itself is unavailable.
    """
    if self._nobs is None or self._scale is None or np.isnan(self._scale):
        return None
    llf = self.llf
    if llf is None:
        # llf yields None e.g. for a non-positive variance estimate.
        return None
    return -2 * llf + 2 * len(self._params)
2409
+
2410
@property
def bic(self):
    """Bayesian Information Criterion.

    ``-2 * llf + len(_params) * log(nobs)``. Returns ``None`` when the
    number of observations is unknown, the scale is missing or NaN, or the
    log-likelihood itself is unavailable.
    """
    if self._nobs is None or self._scale is None or np.isnan(self._scale):
        return None
    llf = self.llf
    if llf is None:
        # llf yields None e.g. for a non-positive variance estimate.
        return None
    n = self._nobs
    k = len(self._params)
    return -2 * llf + k * np.log(n)
2418
+
2419
@property
def llf(self):
    """Log-likelihood.

    Gaussian log-likelihood evaluated at the variance estimate RSS / n.
    RSS comes from the stored residuals when they are on host, otherwise it
    is recovered as ``_scale * _df_resid``. Returns ``None`` when the
    number of observations is unknown, the needed scale or degrees of
    freedom are unusable, or the variance estimate is not positive.
    """
    n = self._nobs
    if n is None:
        return None

    residuals = self._resid
    if residuals is None:
        # No host residuals: rebuild the RSS from the unbiased scale.
        if self._scale is None or np.isnan(self._scale):
            return None
        if self._df_resid is None or self._df_resid <= 0:
            return None
        sigma2_mle = (self._scale * self._df_resid) / n
    else:
        sigma2_mle = np.sum(residuals ** 2) / n

    if sigma2_mle <= 0:
        return None
    return -n/2 * np.log(2 * np.pi * sigma2_mle) - n/2
2436
+
2437
def summary(self):
    """Print summary table.

    Writes the model header, goodness-of-fit statistics, the per-coefficient
    inference table and, when enabled, the simultaneous-inference block to
    standard output.

    Raises
    ------
    RuntimeError
        If the model is not fitted, or if standard errors, p-values or
        confidence intervals are unavailable.
    """
    if not self._fitted:
        raise RuntimeError("Model has not been fitted yet.")

    if any(item is None for item in (self._bse, self._pvalues, self._conf_int)):
        raise RuntimeError(
            "compute_inference=False: inference statistics are not available. "
            "Re-fit with compute_inference=True (default) to use summary()."
        )

    feature_names = [f'x{i+1}' for i in range(len(self.coef_))]
    if self.fit_intercept:
        feature_names.insert(0, '(Intercept)')

    # The debiased estimator reports z-statistics instead of t-statistics.
    if self.inference_method == "debiased":
        title, stat_label, pval_label = "Debiased Lasso Results", "z", "P>|z|"
    else:
        title, stat_label, pval_label = "Lasso Regression Results", "t", "P>|t|"

    def _fmt_stat(value, fmt_spec: str) -> str:
        # Anything that cannot be turned into a float (including None) is
        # rendered as a right-aligned 'nan'.
        try:
            value_f = float(value)
        except Exception:
            return f"{'nan':>15}"
        if np.isnan(value_f):
            return f"{'nan':>15}"
        if np.isposinf(value_f):
            return f"{'inf':>15}"
        if np.isneginf(value_f):
            return f"{'-inf':>15}"
        return format(value_f, fmt_spec)

    heavy_rule = "=" * 80
    light_rule = "-" * 80

    print(heavy_rule)
    if self._inference_cautions:
        print("Notes:")
        for note in self._inference_cautions:
            print(f"- {note}")
        print(heavy_rule)
    print(f" {title}")
    print(f" (alpha = {self.alpha:.4f})")
    print(heavy_rule)
    print(f"No. Observations: {self._nobs:>15}")
    print(f"Degrees of Freedom: {self._df_resid:>15}")
    print(f"Iterations: {self.n_iter_:>15}")

    # Attributes are resolved one at a time, right before each line is
    # printed, so a failing property does not suppress earlier output.
    stat_rows = (
        ("R-squared:", "rsquared", '>15.4f'),
        ("Adj. R-squared:", "rsquared_adj", '>15.4f'),
        ("F-statistic:", "fvalue", '>15.4f'),
        ("Prob (F-statistic):", "f_pvalue", '>15.4e'),
        ("Log-Likelihood:", "llf", '>15.4f'),
        ("AIC:", "aic", '>15.4f'),
        ("BIC:", "bic", '>15.4f'),
    )
    for label, attr_name, spec in stat_rows:
        print(f"{label} {_fmt_stat(getattr(self, attr_name), spec)}")

    print(light_rule)
    print(f"{'':<15} {'coef':>12} {'std err':>12} {stat_label:>10} {pval_label:>10} {'[0.025':>12} {'0.975]':>12}")
    print(light_rule)

    for row, name in enumerate(feature_names):
        left = f"{name:<15} {self._params[row]:>12.4f} {self._bse[row]:>12.4f} "
        middle = f"{self._tvalues[row]:>10.3f} {self._pvalues[row]:>10.4f} "
        right = f"{self._conf_int[row, 0]:>12.4f} {self._conf_int[row, 1]:>12.4f}"
        print(left + middle + right)

    if self._simultaneous_enabled and self._conf_int_simultaneous is not None:
        if self.fit_intercept and self.simultaneous_include_intercept:
            target_txt = "include_intercept=True"
        else:
            target_txt = "include_intercept=False"
        print(light_rule)
        print("Simultaneous inference")
        print(f"method: {self._simultaneous_method}")
        print(f"alpha: {self._simultaneous_alpha:.6f}")
        print(f"n_bootstrap: {self._simultaneous_n_bootstrap}")
        print(f"critical value (max|Z|): {self._simultaneous_critical_value:.6f}")
        print(f"target set: {target_txt}")

    print(heavy_rule)
2516
+
2517
def predict(self, X):
    """Predict using the Lasso model.

    Computes ``X @ coef_ + intercept_`` on the device reported by
    ``self._get_compute_device()``; the result is a CuPy array, a torch
    tensor (float64) or a NumPy array accordingly.
    """
    self._check_is_fitted()
    target = self._get_compute_device()

    if target == Device.CUDA:
        import cupy as cp

        features = cp.asarray(self._to_array(X, Device.CUDA))
        weights = cp.asarray(self.coef_)
        bias = cp.asarray(self.intercept_, dtype=weights.dtype)
        return features @ weights + bias

    if target == Device.TORCH:
        import torch

        features = self._to_array(X, Device.TORCH, backend="torch").to(torch.float64)
        # Coefficients follow the dtype/device of the converted design matrix.
        placement = {"dtype": features.dtype, "device": features.device}
        weights = torch.as_tensor(self.coef_, **placement)
        bias = torch.as_tensor(self.intercept_, **placement)
        return features @ weights + bias

    features = np.asarray(self._to_array(X, Device.CPU))
    return features @ self.coef_ + self.intercept_
2540
+
2541
def score(self, X, y):
    """Return R^2 score.

    ``1 - SS_res / SS_tot`` of the predictions for ``X`` against ``y``,
    evaluated on the active compute device. A constant target
    (``SS_tot`` not strictly positive) yields ``0.0``.
    """
    y_pred = self.predict(X)
    target = self._get_compute_device()

    if target == Device.CUDA:
        import cupy as cp

        y_true = cp.asarray(self._to_array(y, Device.CUDA))
        resid_ss = cp.sum((y_true - y_pred) ** 2)
        total_ss = cp.sum((y_true - cp.mean(y_true)) ** 2)
        if float(total_ss.item()) > 0:
            return float((1 - resid_ss / total_ss).item())
        return 0.0

    if target == Device.TORCH:
        import torch

        # Match the prediction dtype so the subtraction does not promote.
        y_true = self._to_array(y, Device.TORCH, backend="torch").to(y_pred.dtype)
        resid_ss = torch.sum((y_true - y_pred) ** 2)
        total_ss = torch.sum((y_true - torch.mean(y_true)) ** 2)
        if float(total_ss.item()) > 0:
            return float((1 - resid_ss / total_ss).item())
        return 0.0

    y_pred = np.asarray(y_pred)
    y = self._to_numpy(y)
    resid_ss = np.sum((y - y_pred) ** 2)
    total_ss = np.sum((y - np.mean(y)) ** 2)
    if total_ss > 0:
        return 1 - resid_ss / total_ss
    return 0.0
2564
+
2565
+
2566
def _lasso_alpha_heuristic(y_centered: np.ndarray, n_features: int) -> float:
    """Return ``sigma_hat * sqrt(2 * log(max(2, p)) / max(1, n))``.

    ``sigma_hat`` is the standard deviation of ``y_centered`` (sample
    version when more than one observation is present), floored at 1e-8.
    """
    n_obs = int(y_centered.shape[0])
    # A single observation has no sample std; fall back to ddof=0.
    ddof = 1 if n_obs > 1 else 0
    noise_scale = max(float(np.std(y_centered, ddof=ddof)), 1e-8)
    log_term = np.log(max(2, int(n_features)))
    rate = np.sqrt(2.0 * log_term / max(1, n_obs))
    return float(noise_scale * rate)
2575
+
2576
+
2577
def _default_lasso_alpha_grid(
    X: np.ndarray,
    y: np.ndarray,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
) -> np.ndarray:
    """Build a descending, geometrically spaced alpha grid (NumPy inputs).

    The largest alpha is the maximum of ``max|X.T @ y| / n``, the value of
    ``_lasso_alpha_heuristic`` and 1e-6. The smallest is
    ``alpha_min_ratio`` times the largest, floored at 1e-6. With
    ``n_alphas <= 1`` only the largest alpha is returned.
    """
    n_rows = int(X.shape[0])
    abs_corr = np.abs(X.T @ y) / float(max(1, n_rows))

    top = float(np.max(abs_corr)) if abs_corr.size else 1.0
    top = max(top, _lasso_alpha_heuristic(y, n_features=int(X.shape[1])))
    top = max(top, 1e-6)

    count = int(n_alphas)
    if count <= 1:
        return np.asarray([top], dtype=np.float64)

    bottom = max(float(alpha_min_ratio) * top, 1e-6)
    return np.geomspace(top, bottom, num=count).astype(np.float64)
2594
+
2595
+
2596
def _default_lasso_alpha_grid_backend(
    X,
    y,
    backend,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
) -> np.ndarray:
    """Generate default alpha grid for Lasso using backend abstraction.

    Parameters
    ----------
    X, y
        Design matrix and response; converted with ``backend.asarray`` to
        ``backend.float64`` (``y`` is flattened).
    backend
        Object providing ``asarray``, ``float64``, ``abs``, ``max``,
        ``mean`` and ``to_numpy``.
    n_alphas
        Number of grid points; values ``<= 1`` return only the largest alpha.
    alpha_min_ratio
        Ratio between the smallest and the largest alpha.

    Returns
    -------
    np.ndarray
        Host-side float64 grid in descending order.
    """
    X_arr = backend.asarray(X, dtype=backend.float64)
    y_arr = backend.asarray(y, dtype=backend.float64).reshape(-1)

    n_samples = int(X_arr.shape[0])
    corr = backend.abs(X_arr.T @ y_arr) / float(max(1, n_samples))
    # Use shape to check size - works for both numpy and torch
    corr_size = int(corr.shape[0]) if hasattr(corr, 'shape') else len(corr)
    alpha_max = float(backend.to_numpy(backend.max(corr))) if corr_size > 0 else 1.0

    n_y = int(y_arr.shape[0])
    if n_y > 1:
        centered = y_arr - backend.mean(y_arr)
        pop_var = float(backend.to_numpy(backend.mean(centered ** 2)))
        # Bessel correction: the NumPy and CuPy grid builders use the sample
        # standard deviation (ddof=1), so do the same here to obtain the
        # same grid regardless of backend.
        sigma_hat = float(np.sqrt(pop_var * n_y / (n_y - 1)))
    else:
        sigma_hat = 0.0

    sigma_hat = max(sigma_hat, 1e-8)
    penalty_scale = np.sqrt(2.0 * np.log(max(2, int(X_arr.shape[1]))) / max(1, n_samples))
    alpha_max = max(alpha_max, float(sigma_hat * penalty_scale), 1e-6)

    if int(n_alphas) <= 1:
        return np.asarray([alpha_max], dtype=np.float64)

    alpha_min = max(float(alpha_min_ratio) * alpha_max, 1e-6)
    return np.geomspace(alpha_max, alpha_min, num=int(n_alphas)).astype(np.float64)
2628
+
2629
+
2630
def _default_lasso_alpha_grid_cupy(
    X,
    y,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
) -> np.ndarray:
    """Build the default descending alpha grid with CuPy reductions.

    Inputs are converted to float64 CuPy arrays; scalar reductions are
    pulled back with ``.item()`` and the returned grid is a host-side
    NumPy float64 array.
    """
    import cupy as cp

    X_dev = cp.asarray(X, dtype=cp.float64)
    y_dev = cp.asarray(y, dtype=cp.float64).reshape(-1)

    n_rows = int(X_dev.shape[0])
    abs_corr = cp.abs(X_dev.T @ y_dev) / float(max(1, n_rows))
    top = float(cp.max(abs_corr).item()) if int(abs_corr.size) > 0 else 1.0

    # Sample std when more than one row is available, population std otherwise.
    if n_rows > 1:
        noise_scale = float(cp.std(y_dev, ddof=1).item())
    else:
        noise_scale = float(cp.std(y_dev).item())
    noise_scale = max(noise_scale, 1e-8)

    rate = np.sqrt(2.0 * np.log(max(2, int(X_dev.shape[1]))) / max(1, n_rows))
    top = max(top, float(noise_scale * rate), 1e-6)

    count = int(n_alphas)
    if count <= 1:
        return np.asarray([top], dtype=np.float64)

    bottom = max(float(alpha_min_ratio) * top, 1e-6)
    return np.geomspace(top, bottom, num=count).astype(np.float64)
2659
+
2660
+
2661
def _kfold_indices(n_samples: int, n_splits: int, random_state: Optional[int]):
    """Return shuffled K-fold ``(train_idx, val_idx)`` pairs.

    The number of folds is ``n_splits`` clamped to ``[2, n_samples]``
    (never below 2). Folds with an empty train or validation part are
    dropped; if nothing remains, one fold using every index for both
    roles is returned.
    """
    n = int(n_samples)
    k = max(2, min(int(n_splits), n))

    shuffled = np.random.default_rng(random_state).permutation(n)

    # The first ``n % k`` folds receive one extra sample.
    sizes = np.full(k, n // k, dtype=np.int64)
    sizes[: n % k] += 1
    stops = np.cumsum(sizes)
    starts = stops - sizes

    folds = []
    for lo, hi in zip(starts, stops):
        lo, hi = int(lo), int(hi)
        val_idx = shuffled[lo:hi]
        train_idx = np.concatenate([shuffled[:lo], shuffled[hi:]])
        if train_idx.size and val_idx.size:
            folds.append((train_idx, val_idx))

    if not folds:
        everything = np.arange(n, dtype=np.int64)
        return [(everything, everything)]

    return folds
2687
+
2688
+
2689
def _normalize_cv_splits(cv_splits, n_samples: int):
    """Validate user-supplied CV splits and return them as int64 index pairs.

    ``None`` is passed through. Splits with an empty train or validation
    part are silently dropped.

    Raises
    ------
    ValueError
        If an entry is not a 2-item tuple/list, if any index lies outside
        ``[0, n_samples)``, or if no non-empty split remains.
    """
    if cv_splits is None:
        return None

    upper = int(n_samples)

    def _out_of_range(idx) -> bool:
        return bool(np.any(idx < 0)) or bool(np.any(idx >= upper))

    normalized = []
    for entry in cv_splits:
        if not (isinstance(entry, (tuple, list)) and len(entry) == 2):
            raise ValueError("Each cv_splits entry must be a (train_idx, val_idx) pair")

        train_idx = np.asarray(entry[0], dtype=np.int64).reshape(-1)
        val_idx = np.asarray(entry[1], dtype=np.int64).reshape(-1)

        if train_idx.size == 0 or val_idx.size == 0:
            continue

        if _out_of_range(train_idx) or _out_of_range(val_idx):
            raise ValueError("cv_splits indices are out of range")

        normalized.append((train_idx, val_idx))

    if not normalized:
        raise ValueError("cv_splits must contain at least one non-empty split")

    return normalized
2720
+
2721
+
2722
def _folds_are_complements(folds, n_samples: int) -> bool:
    """Return True when each fold uses train as the exact complement of validation.

    For every fold the train and validation indices must together contain
    exactly ``n_samples`` entries, must not overlap, and must cover every
    index in ``range(n_samples)``.
    """
    n = int(n_samples)
    for train_idx, val_idx in folds:
        train = np.asarray(train_idx, dtype=np.int64).reshape(-1)
        val = np.asarray(val_idx, dtype=np.int64).reshape(-1)

        if int(train.size + val.size) != n:
            return False

        covered = np.zeros((n,), dtype=np.int8)
        covered[train] = 1
        # Any validation index already marked means train and val overlap.
        if bool(np.any(covered[val] != 0)):
            return False
        covered[val] = 1
        # A remaining zero means some sample is in neither part.
        if bool(np.any(covered == 0)):
            return False

    return True
2741
+
2742
+
2743
def _array_identity_token(x: Any) -> Tuple[Any, ...]:
    """Return a cheap identity token ``(kind, pointer, shape, dtype)`` for ``x``.

    The token identifies the buffer an array-like object points at, not its
    contents: ``None`` maps to ``("none",)``; CuPy arrays, torch tensors and
    everything else (via ``np.asarray``) are tagged ``"cupy"``, ``"torch"``
    and ``"numpy"`` respectively.

    NOTE(review): because the token is address-based, in-place modification
    of an array does not change it, and objects that ``np.asarray`` has to
    copy (e.g. lists) are keyed on a temporary buffer - confirm callers only
    rely on it for live array inputs.
    """
    if x is None:
        return ("none",)

    # CuPy / torch are optional; a failing import (or any backend error while
    # probing) simply falls through to the next representation.
    try:
        import cupy as cp

        if isinstance(x, cp.ndarray):
            return ("cupy", int(x.data.ptr), tuple(int(v) for v in x.shape), str(x.dtype))
    except Exception:
        pass

    # Check for Torch tensors
    try:
        import torch

        if isinstance(x, torch.Tensor):
            # data_ptr() is the address of the tensor's first element on both
            # CPU and CUDA. The storage base pointer must not be used: views
            # at different offsets of one storage would share a token.
            ptr = int(x.data_ptr())
            return ("torch", ptr, tuple(int(v) for v in x.shape), str(x.dtype))
    except Exception:
        pass

    arr = np.asarray(x)
    ptr = int(arr.__array_interface__["data"][0]) if int(arr.size) > 0 else 0
    return ("numpy", ptr, tuple(int(v) for v in arr.shape), str(arr.dtype))
2773
+
2774
+
2775
def _alphas_signature(alphas: np.ndarray) -> str:
    """Return a 16-byte BLAKE2b hex digest of the float64 bytes of ``alphas``."""
    flat = np.asarray(alphas, dtype=np.float64).reshape(-1)
    hasher = hashlib.blake2b(digest_size=16)
    hasher.update(np.ascontiguousarray(flat).tobytes())
    return hasher.hexdigest()
2778
+
2779
+
2780
def _folds_signature(folds) -> str:
    """Return a 16-byte BLAKE2b hex digest over all fold index arrays.

    Each fold contributes its int64 train bytes, ``b"|"``, its int64
    validation bytes and ``b";"``, in that order.
    """

    def _index_bytes(idx) -> bytes:
        flat = np.asarray(idx, dtype=np.int64).reshape(-1)
        return np.ascontiguousarray(flat).tobytes()

    hasher = hashlib.blake2b(digest_size=16)
    for train_idx, val_idx in folds:
        for chunk in (_index_bytes(train_idx), b"|", _index_bytes(val_idx), b";"):
            hasher.update(chunk)
    return hasher.hexdigest()
2790
+
2791
+
2792
def _make_lasso_cv_auto_cache_key(
    *,
    X,
    y,
    sample_weight,
    alpha_grid: np.ndarray,
    folds,
    fit_intercept: bool,
    use_gpu: bool,
    max_iter: int,
    tol: float,
    cpu_solver: str,
    cv_method: str,
    cd_kkt_check_every: Optional[int],
    gpu_cv_mixed_precision: bool,
) -> Tuple[Any, ...]:
    """Assemble the hashable cache key for an automatic Lasso CV run.

    The key starts with the ``"lasso_cv_auto_v1"`` tag, followed by
    identity tokens for the data arrays, digests of the alpha grid and the
    folds, and the normalised solver settings.
    """
    data_tokens = tuple(_array_identity_token(arr) for arr in (X, y, sample_weight))
    grid_digest = _alphas_signature(alpha_grid)
    folds_digest = _folds_signature(folds)

    if cd_kkt_check_every is None:
        kkt_cadence = None
    else:
        kkt_cadence = int(cd_kkt_check_every)

    return (
        "lasso_cv_auto_v1",
        *data_tokens,
        grid_digest,
        folds_digest,
        bool(fit_intercept),
        bool(use_gpu),
        int(max_iter),
        float(tol),
        str(cpu_solver).lower(),
        str(cv_method).lower(),
        kkt_cadence,
        bool(gpu_cv_mixed_precision),
    )
2824
+
2825
+
2826
def _clone_lasso_cv_cache_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return an independent copy of a CV cache payload.

    ``"alpha"`` is coerced to ``float``; ``"alphas"``, ``"mse_path"`` and
    ``"mean_mse"`` are copied as float64 arrays so that neither side can
    mutate the other's data.
    """
    cloned: Dict[str, Any] = {"alpha": float(payload["alpha"])}
    for key in ("alphas", "mse_path", "mean_mse"):
        cloned[key] = np.asarray(payload[key], dtype=np.float64).copy()
    return cloned
2833
+
2834
+
2835
def _lasso_cv_cache_get(cache_key: Optional[Tuple[Any, ...]]) -> Optional[Dict[str, Any]]:
    """Look up ``cache_key`` in the module-level CV cache.

    Returns a cloned payload and marks the entry as most recently used, or
    ``None`` when the key is ``None``, caching is disabled (max size
    ``<= 0``) or the key is absent.
    """
    if cache_key is None or _LASSO_CV_ALPHA_CACHE_MAXSIZE <= 0:
        return None

    entry = _LASSO_CV_ALPHA_CACHE.get(cache_key)
    if entry is not None:
        # Refresh recency before handing out a private copy.
        _LASSO_CV_ALPHA_CACHE.move_to_end(cache_key)
        return _clone_lasso_cv_cache_payload(entry)
    return None
2845
+
2846
+
2847
def _lasso_cv_cache_put(cache_key: Optional[Tuple[Any, ...]], payload: Dict[str, Any]) -> None:
    """Store a cloned ``payload`` under ``cache_key`` and evict old entries.

    Does nothing when the key is ``None`` or caching is disabled (max size
    ``<= 0``). After insertion, entries are dropped from the least recently
    used end until the cache fits its maximum size.
    """
    if cache_key is None or _LASSO_CV_ALPHA_CACHE_MAXSIZE <= 0:
        return

    _LASSO_CV_ALPHA_CACHE[cache_key] = _clone_lasso_cv_cache_payload(payload)
    _LASSO_CV_ALPHA_CACHE.move_to_end(cache_key)

    capacity = int(_LASSO_CV_ALPHA_CACHE_MAXSIZE)
    overflow = len(_LASSO_CV_ALPHA_CACHE) - capacity
    for _ in range(overflow):
        _LASSO_CV_ALPHA_CACHE.popitem(last=False)
2856
+
2857
+
2858
def _adaptive_gpu_check_every(
    *,
    base_check_every: int,
    iteration: int,
    max_iter: int,
    active_ratio: float,
) -> int:
    """Adaptive cadence for expensive global convergence checks on GPU.

    The interval grows with the fraction of still-active problems
    (``active_ratio``, clamped to ``[0, 1]``) and is capped once the
    iteration count approaches ``max_iter``. The result is always >= 1.
    """
    base = max(1, int(base_check_every))
    ratio = float(max(0.0, min(1.0, active_ratio)))

    # (minimum active ratio, minimum interval) from busiest to quietest.
    for cutoff, floor in ((0.75, 16), (0.40, 12), (0.15, 4)):
        if ratio >= cutoff:
            interval = max(base, floor)
            break
    else:
        interval = max(2, base // 2)

    # Check more often near the iteration budget so convergence is not missed.
    progress = float(iteration + 1) / float(max(1, int(max_iter)))
    if progress >= 0.90:
        interval = min(interval, 2)
    elif progress >= 0.75:
        interval = min(interval, 4)

    return max(1, int(interval))
2885
+
2886
+
2887
def _soft_threshold_numpy(x: np.ndarray, gamma: float) -> np.ndarray:
    """Element-wise soft-thresholding ``sign(x) * max(|x| - gamma, 0)``.

    ``gamma`` is converted to a float64 array, so it may be a scalar or any
    array that broadcasts against ``x``.
    """
    shrink = np.asarray(gamma, dtype=np.float64)
    magnitude = np.maximum(np.abs(x) - shrink, 0.0)
    return np.sign(x) * magnitude
2890
+
2891
+
2892
def _soft_threshold_scalar(x: float, gamma: float) -> float:
    """Scalar soft-thresholding: shrink ``x`` towards zero by ``gamma``.

    Returns ``0.0`` when ``|x| <= gamma`` and ``sign(x) * (|x| - gamma)``
    otherwise, always as a Python ``float``.
    """
    magnitude = abs(float(x))
    shrink = float(gamma)
    return 0.0 if magnitude <= shrink else float(np.sign(x) * (magnitude - shrink))
2898
+
2899
+
2900
# Numba-compiled coordinate-descent kernels; only defined when Numba imported.
if _NUMBA_AVAILABLE:

    @njit(cache=True)
    def _soft_threshold_scalar_numba(x: float, gamma: float) -> float:
        # Scalar soft-thresholding: 0 when |x| <= gamma, otherwise x shrunk
        # towards zero by gamma with its sign preserved.
        ax = abs(x)
        if ax <= gamma:
            return 0.0
        if x >= 0.0:
            return ax - gamma
        return -(ax - gamma)


    @njit(cache=True)
    def _solve_lasso_path_cpu_cd_numba_impl(
        XtX: np.ndarray,
        Xty: np.ndarray,
        n_samples: int,
        alphas_desc: np.ndarray,
        max_iter: int,
        tol: float,
        stopping_is_kkt: bool,
        cd_kkt_check_every: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        # Coordinate-descent Lasso path driven only by the Gram matrix XtX
        # and the vector Xty.
        #
        # Alphas are visited in the order given. `coef`, `grad` and
        # `active_mask` are NOT reset between alphas, so each alpha is
        # warm-started from the previous solution and the active set only
        # ever grows.
        #
        # Returns (coefs_path, n_iters): coefs_path has shape
        # (n_alphas, n_features); n_iters (int32) holds the sweep count per
        # alpha, equal to max_iter when the stopping rule never fired.
        n_features = XtX.shape[0]
        n_alphas = alphas_desc.shape[0]

        coefs_path = np.zeros((n_alphas, n_features), dtype=np.float64)
        n_iters = np.zeros((n_alphas,), dtype=np.int32)

        coef = np.zeros((n_features,), dtype=np.float64)
        # Unscaled gradient XtX @ coef - Xty; equals -Xty at coef == 0 and is
        # kept up to date incrementally after every coordinate move.
        grad = -Xty.copy()

        # Diagonal of the Gram matrix = squared column norms of X.
        X_sq_norms = np.empty((n_features,), dtype=np.float64)
        for j in range(n_features):
            X_sq_norms[j] = XtX[j, j]

        n_samp = float(max(1, n_samples))
        # The coordinate update works on the unscaled gradient, so the
        # penalty is multiplied by n; the KKT scan below instead divides the
        # gradient by n and compares against the raw alpha.
        alpha_scaled_desc = np.empty((n_alphas,), dtype=np.float64)
        for idx in range(n_alphas):
            alpha_scaled_desc[idx] = alphas_desc[idx] * n_samp

        active_mask = np.zeros((n_features,), dtype=np.bool_)
        check_every = max(1, int(cd_kkt_check_every))

        for alpha_idx in range(n_alphas):
            alpha = float(alphas_desc[alpha_idx])
            alpha_scaled = float(alpha_scaled_desc[alpha_idx])
            if alpha_idx > 0:
                prev_alpha_scaled = float(alpha_scaled_desc[alpha_idx - 1])
            else:
                prev_alpha_scaled = alpha_scaled

            # Screening threshold 2*alpha - alpha_prev (clamped at 0), applied
            # to |Xty[j]|. NOTE(review): this resembles a strong-rule screen
            # but uses Xty rather than the current gradient; features missed
            # here are re-admitted by the KKT scan further down.
            strong_thresh = 2.0 * alpha_scaled - prev_alpha_scaled
            if strong_thresh < 0.0:
                strong_thresh = 0.0

            any_active = False
            max_abs_xty = -1.0
            max_abs_xty_idx = 0
            for j in range(n_features):
                abs_xty = abs(Xty[j])
                if abs_xty >= strong_thresh:
                    active_mask[j] = True
                    any_active = True
                if abs_xty > max_abs_xty:
                    max_abs_xty = abs_xty
                    max_abs_xty_idx = j

            # Never sweep an empty set: fall back to the feature with the
            # largest |Xty|.
            if not any_active:
                active_mask[max_abs_xty_idx] = True

            converged = False

            for iteration in range(int(max_iter)):
                coef_delta_l1 = 0.0

                # One sweep over the active coordinates.
                for j in range(n_features):
                    if not active_mask[j]:
                        continue

                    denom = float(X_sq_norms[j])
                    old_val = float(coef[j])

                    if denom > 1e-10:
                        rho_j = -float(grad[j]) + denom * old_val
                        new_val = _soft_threshold_scalar_numba(rho_j, alpha_scaled) / denom
                    else:
                        # (Numerically) zero column: its coefficient is pinned to 0.
                        new_val = 0.0

                    delta = new_val - old_val
                    if delta != 0.0:
                        coef[j] = new_val
                        coef_delta_l1 += abs(delta)
                        # Rank-one gradient update with column j of the Gram matrix.
                        for row_idx in range(n_features):
                            grad[row_idx] += XtX[row_idx, j] * delta

                # Full KKT scan on the configured cadence, when the sweep
                # barely moved, and always on the last allowed iteration.
                should_kkt_scan = (
                    ((iteration + 1) % check_every == 0)
                    or (coef_delta_l1 < float(tol))
                    or (iteration + 1 == int(max_iter))
                )

                violation = 0.0
                has_inactive_violation = False

                if should_kkt_scan:
                    for j in range(n_features):
                        # Amount by which |grad_j| / n exceeds alpha (0 if it does not).
                        v = abs(grad[j] / n_samp) - alpha
                        if v < 0.0:
                            v = 0.0
                        if v > violation:
                            violation = v
                        # A violating inactive feature is added to the active set.
                        if v > float(tol) and (not active_mask[j]):
                            active_mask[j] = True
                            has_inactive_violation = True

                if stopping_is_kkt:
                    # KKT stopping: only decided on iterations where a scan ran.
                    if should_kkt_scan and violation < float(tol):
                        n_iters[alpha_idx] = int(iteration) + 1
                        converged = True
                        break
                else:
                    # Coefficient-change stopping; refused while the scan just
                    # activated a new feature.
                    if coef_delta_l1 < float(tol) and (not has_inactive_violation):
                        n_iters[alpha_idx] = int(iteration) + 1
                        converged = True
                        break

            if not converged:
                n_iters[alpha_idx] = int(max_iter)

            # Record the solution; non-zero coefficients stay active for the
            # next alpha.
            for j in range(n_features):
                coefs_path[alpha_idx, j] = coef[j]
                if abs(coef[j]) > 0.0:
                    active_mask[j] = True

        return coefs_path, n_iters
3036
+
3037
+
3038
def _solve_lasso_path_cpu_cd_numba(
    XtX: np.ndarray,
    Xty: np.ndarray,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    cd_kkt_check_every: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Python-side wrapper around ``_solve_lasso_path_cpu_cd_numba_impl``.

    Normalises the arguments to the types the compiled kernel expects
    (C-contiguous float64 arrays, plain ints/floats/bools) and maps
    ``stopping == "kkt"`` (case-insensitive) to the boolean flag.
    """
    gram = np.ascontiguousarray(XtX, dtype=np.float64)
    rhs = np.ascontiguousarray(Xty, dtype=np.float64)
    alpha_path = np.ascontiguousarray(np.asarray(alphas_desc, dtype=np.float64))
    use_kkt = str(stopping).lower() == "kkt"

    kernel_args = (
        gram,
        rhs,
        int(n_samples),
        alpha_path,
        int(max_iter),
        float(tol),
        bool(use_kkt),
        int(cd_kkt_check_every),
    )
    return _solve_lasso_path_cpu_cd_numba_impl(*kernel_args)
3063
+
3064
+
3065
def _normalize_lassocv_method(method: str) -> str:
    """Normalize CV optimization profile name.

    The input is stringified, stripped and lower-cased; the aliases
    ``default``/``classic`` map to ``"standard"`` and
    ``glmnet_cv``/``glmnet.cv`` map to ``"glmnet"``.

    Raises
    ------
    ValueError
        If the resolved name is neither ``"standard"`` nor ``"glmnet"``.
    """
    aliases = {
        "default": "standard",
        "classic": "standard",
        "glmnet_cv": "glmnet",
        "glmnet.cv": "glmnet",
    }
    requested = str(method).strip().lower()
    canonical = aliases.get(requested, requested)
    if canonical in ("standard", "glmnet"):
        return canonical
    raise ValueError("method must be one of: 'standard', 'glmnet'")
3078
+
3079
+
3080
def _normalize_cd_kkt_check_every(cd_kkt_check_every: Optional[int]) -> Optional[int]:
    """Validate optional coordinate-descent global KKT scan cadence.

    ``None`` is passed through; any other value is converted with ``int``
    and must be strictly positive.

    Raises
    ------
    ValueError
        If the converted value is zero or negative.
    """
    if cd_kkt_check_every is None:
        return None
    cadence = int(cd_kkt_check_every)
    if cadence > 0:
        return cadence
    raise ValueError("cd_kkt_check_every must be a positive integer or None")
3088
+
3089
+
3090
def _solve_lasso_path_cpu_fista_batched_from_gram(
    XtX: np.ndarray,
    Xty: np.ndarray,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 2,
) -> tuple[np.ndarray, np.ndarray]:
    """Solve descending-alpha Lasso path with a batched CPU FISTA update.

    Every alpha is one column of a ``(n_features, n_alphas)`` coefficient
    matrix; all columns start from zero and are advanced together, with
    converged columns removed from the working set.

    Parameters
    ----------
    XtX, Xty
        Gram matrix and ``X.T @ y`` vector.
    n_samples
        Sample count used to scale the gradient (values below 1 act as 1).
    alphas_desc
        Penalty values, one per column of the batch.
    max_iter, tol
        Iteration budget and convergence tolerance.
    stopping
        ``"kkt"`` (case-insensitive) tests the KKT violation; any other
        value tests the L1 change of the coefficients in the last step.
    lipschitz_L
        Optional pre-computed Lipschitz constant; otherwise the largest
        eigenvalue of ``XtX`` divided by ``n_samples`` is used.
    check_every
        Convergence is tested every ``check_every`` iterations (minimum 1)
        and on the last iteration.

    Returns
    -------
    (coefs, n_iters)
        ``coefs`` has shape ``(n_alphas, n_features)``; ``n_iters`` (int32)
        holds the iteration at which each alpha converged, or ``max_iter``.
    """
    n_features = int(XtX.shape[0])
    n_alphas = int(alphas_desc.shape[0])

    coefs = np.zeros((n_features, n_alphas), dtype=np.float64)
    # FISTA extrapolation point and momentum scalar, one per alpha column.
    yk = coefs.copy()
    tk = np.ones((n_alphas,), dtype=np.float64)
    n_iters = np.zeros((n_alphas,), dtype=np.int32)

    if lipschitz_L is not None:
        L = float(lipschitz_L)
    else:
        try:
            # eigvalsh returns eigenvalues in ascending order; take the largest.
            eigvals = np.linalg.eigvalsh(XtX)
            L = float(eigvals[-1] / float(max(1, n_samples)))
        except Exception:
            # Fallback: the max absolute row sum bounds the spectral radius.
            row_sum_bound = float(np.max(np.sum(np.abs(XtX), axis=1)) / float(max(1, n_samples)))
            L = max(row_sum_bound, 1e-12)

    # No usable step size: return the all-zero path with zero iterations.
    if L <= 0.0:
        return coefs.T, n_iters

    n_samp = float(max(1, n_samples))
    step = 1.0 / L
    alphas_desc = np.asarray(alphas_desc, dtype=np.float64)
    # Soft-threshold level per alpha for a proximal step of size `step`.
    thresholds = alphas_desc * step
    stopping_name = str(stopping).lower()
    check_every = max(1, int(check_every))

    # Indices of alpha columns that have not converged yet.
    active = np.arange(n_alphas, dtype=np.int64)

    for iteration in range(int(max_iter)):
        if active.size == 0:
            break

        # Fancy indexing yields copies, so coef_old survives the write-back below.
        y_active = yk[:, active]
        coef_old = coefs[:, active]

        # Gradient of the smooth term at the extrapolation point.
        grad = (XtX @ y_active - Xty.reshape(-1, 1)) / n_samp
        thresh = thresholds[active].reshape(1, -1)
        coef_new = _soft_threshold_numpy(y_active - step * grad, thresh)

        # Momentum update followed by extrapolation.
        t_old = tk[active]
        t_new = (1.0 + np.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
        beta = (t_old - 1.0) / t_new
        y_new = coef_new + beta.reshape(1, -1) * (coef_new - coef_old)

        coefs[:, active] = coef_new
        yk[:, active] = y_new
        tk[active] = t_new

        should_check = ((iteration + 1) % check_every == 0) or (iteration + 1 == int(max_iter))
        if not should_check:
            continue

        if stopping_name == "kkt":
            # Largest excess of |gradient| over alpha across features, per column.
            grad_sse = (XtX @ coef_new - Xty.reshape(-1, 1)) / n_samp
            viol = np.max(
                np.maximum(
                    np.abs(grad_sse) - alphas_desc[active].reshape(1, -1),
                    0.0,
                ),
                axis=0,
            )
            converged_local = viol < float(tol)
        else:
            delta = np.sum(np.abs(coef_new - coef_old), axis=0)
            converged_local = delta < float(tol)

        if not np.any(converged_local):
            continue

        done = active[converged_local]
        n_iters[done] = int(iteration) + 1
        # Pin the extrapolation point of finished columns to their solution.
        yk[:, done] = coefs[:, done]
        active = active[~converged_local]

    # Columns still active exhausted the iteration budget.
    if active.size > 0:
        n_iters[active] = int(max_iter)

    return coefs.T, n_iters
3183
+
3184
+
3185
def _solve_lasso_path_gpu_fista_batched_from_gram(
    XtX,
    Xty,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve descending-alpha Lasso path with a batched GPU FISTA update.

    CuPy counterpart of the batched CPU solver: each alpha is a column of a
    ``(n_features, n_alphas)`` matrix created with ``XtX.dtype``; all
    columns start at zero and are iterated together.

    Parameters
    ----------
    XtX, Xty
        Gram matrix and ``X.T @ y``. Presumably CuPy arrays already on the
        device - they are used directly with ``cp`` operations.
    n_samples
        Sample count used to scale the gradient (values below 1 act as 1).
    alphas_desc
        Penalty values, one per column of the batch.
    max_iter, tol, stopping
        Iteration budget, tolerance and stopping rule (``"kkt"`` or
        coefficient change).
    lipschitz_L
        Optional pre-computed Lipschitz constant.
    check_every
        Base cadence for convergence checks; the effective cadence is
        chosen per iteration by ``_adaptive_gpu_check_every``.

    Returns
    -------
    (coefs, n_iters)
        ``coefs`` is a CuPy array of shape ``(n_alphas, n_features)``;
        ``n_iters`` is a NumPy int32 array.
    """
    import cupy as cp

    n_features = int(XtX.shape[0])
    n_alphas = int(alphas_desc.shape[0])

    coefs = cp.zeros((n_features, n_alphas), dtype=XtX.dtype)
    # FISTA extrapolation point and momentum scalar, one per alpha column.
    yk = coefs.copy()
    tk = cp.ones((n_alphas,), dtype=XtX.dtype)
    n_iters_gpu = cp.zeros((n_alphas,), dtype=cp.int32)

    if lipschitz_L is not None:
        L = cp.array(float(lipschitz_L), dtype=XtX.dtype)
    else:
        try:
            eigvals = cp.linalg.eigvalsh(XtX)
            L = eigvals[-1] / float(max(1, n_samples))
        except Exception:
            # Fallback bound: maximum absolute row sum of the Gram matrix.
            row_sum_bound = cp.max(cp.sum(cp.abs(XtX), axis=1)) / float(max(1, n_samples))
            L = cp.maximum(row_sum_bound, cp.asarray(1e-12, dtype=XtX.dtype))

    # Pull L to the host once, only to decide whether a step size exists.
    L_scalar = float(cp.asnumpy(L))
    if L_scalar <= 0.0:
        return coefs.T, np.zeros((n_alphas,), dtype=np.int32)

    n_samp = float(max(1, n_samples))
    step = 1.0 / L
    alphas_desc = np.asarray(alphas_desc, dtype=np.float64)
    alpha_gpu = cp.asarray(alphas_desc, dtype=XtX.dtype)
    # Soft-threshold level per alpha for a proximal step of size `step`.
    thresholds = alpha_gpu * step
    stopping_name = str(stopping).lower()
    check_every = max(1, int(check_every))

    # Indices of alpha columns that have not converged yet.
    active_gpu = cp.arange(n_alphas, dtype=cp.int32)

    for iteration in range(int(max_iter)):
        if int(active_gpu.size) == 0:
            break

        y_active = yk[:, active_gpu]
        coef_old = coefs[:, active_gpu]

        # Gradient of the smooth term at the extrapolation point.
        grad = (XtX @ y_active - Xty.reshape(-1, 1)) / n_samp
        thresh = thresholds[active_gpu].reshape(1, -1)
        # Proximal step: soft-thresholding of the gradient step.
        coef_new = cp.sign(y_active - step * grad) * cp.maximum(cp.abs(y_active - step * grad) - thresh, 0.0)

        # Momentum update followed by extrapolation.
        t_old = tk[active_gpu]
        t_new = (1.0 + cp.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
        beta = (t_old - 1.0) / t_new
        y_new = coef_new + beta.reshape(1, -1) * (coef_new - coef_old)

        coefs[:, active_gpu] = coef_new
        yk[:, active_gpu] = y_new
        tk[active_gpu] = t_new

        # The check cadence depends on how many columns are still active and
        # on how much of the iteration budget has been used.
        active_ratio = float(int(active_gpu.size)) / float(max(1, n_alphas))
        check_every_eff = _adaptive_gpu_check_every(
            base_check_every=check_every,
            iteration=iteration,
            max_iter=int(max_iter),
            active_ratio=active_ratio,
        )
        should_check = ((iteration + 1) % check_every_eff == 0) or (iteration + 1 == int(max_iter))
        if not should_check:
            continue

        if stopping_name == "kkt":
            # Largest excess of |gradient| over alpha across features, per column.
            grad_sse = (XtX @ coef_new - Xty.reshape(-1, 1)) / n_samp
            viol = cp.max(
                cp.maximum(
                    cp.abs(grad_sse) - alpha_gpu[active_gpu].reshape(1, -1),
                    0.0,
                ),
                axis=0,
            )
            converged_local_gpu = viol < float(tol)
        else:
            delta = cp.sum(cp.abs(coef_new - coef_old), axis=0)
            converged_local_gpu = delta < float(tol)

        done_gpu = active_gpu[converged_local_gpu]
        if int(done_gpu.size) == 0:
            continue

        n_iters_gpu[done_gpu] = int(iteration) + 1
        # Pin the extrapolation point of finished columns to their solution.
        yk[:, done_gpu] = coefs[:, done_gpu]
        active_gpu = active_gpu[~converged_local_gpu]

    # Columns still active exhausted the iteration budget.
    if int(active_gpu.size) > 0:
        n_iters_gpu[active_gpu] = int(max_iter)

    return coefs.T, cp.asnumpy(n_iters_gpu)
3289
+
3290
+
3291
def _solve_lasso_path_gpu_fista_multi_fold_from_gram(
    XtX_batch,
    Xty_batch,
    *,
    n_samples_vec,
    alphas_desc,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve descending-alpha Lasso paths for all folds together on GPU.

    Note: Fused kernel optimization is disabled for multi-fold solver due to
    dtype complexity. The single-fold Lasso solver uses fused kernels.

    Parameters
    ----------
    XtX_batch
        Per-fold Gram matrices; the code reads it as
        ``(n_folds, n_features, n_features)``.
    Xty_batch
        Per-fold ``X.T @ y``; reshaped to ``(n_folds, n_features, 1)``.
    n_samples_vec
        One sample count per fold (CuPy or array-like).
    alphas_desc
        Penalty values shared by every fold (CuPy or array-like).
    max_iter, tol, stopping
        Iteration budget, tolerance and stopping rule (``"kkt"`` or
        coefficient change).
    lipschitz_L
        Optional Lipschitz constant applied to every fold.
    check_every
        Base cadence for convergence checks; adapted each iteration by
        ``_adaptive_gpu_check_every``.

    Returns
    -------
    (coefs, n_iters)
        ``coefs`` is a CuPy array of shape ``(n_folds, n_alphas, n_features)``;
        ``n_iters`` is a NumPy int32 array of shape ``(n_folds, n_alphas)``.

    Raises
    ------
    ValueError
        If ``n_samples_vec`` does not have exactly one entry per fold.
    """
    import cupy as cp

    n_folds = int(XtX_batch.shape[0])
    n_features = int(XtX_batch.shape[1])
    n_alphas = int(alphas_desc.shape[0])

    coefs = cp.zeros((n_folds, n_features, n_alphas), dtype=XtX_batch.dtype)
    # FISTA extrapolation point and momentum scalar per (fold, alpha) problem.
    yk = coefs.copy()
    tk = cp.ones((n_folds, n_alphas), dtype=XtX_batch.dtype)
    n_iters_gpu = cp.zeros((n_folds, n_alphas), dtype=cp.int32)

    # Convert n_samples_vec to numpy using .get() if it's a CuPy array
    if hasattr(n_samples_vec, 'get'):
        n_vec_cpu = n_samples_vec.get().astype(np.float64).reshape(-1)
    else:
        n_vec_cpu = np.asarray(n_samples_vec, dtype=np.float64).reshape(-1)
    if n_vec_cpu.size != n_folds:
        raise ValueError("n_samples_vec must have one entry per fold")
    n_vec = cp.asarray(n_vec_cpu, dtype=XtX_batch.dtype)

    if lipschitz_L is not None:
        L = cp.full((n_folds,), float(lipschitz_L), dtype=XtX_batch.dtype)
    else:
        try:
            # Batched eigendecomposition: largest eigenvalue of each fold.
            eigvals = cp.linalg.eigvalsh(XtX_batch)
            L = eigvals[:, -1] / n_vec
        except Exception:
            # Fallback bound: maximum absolute row sum of each Gram matrix.
            row_sum_bound = cp.max(cp.sum(cp.abs(XtX_batch), axis=2), axis=1) / n_vec
            L = cp.maximum(row_sum_bound, cp.asarray(1e-12, dtype=XtX_batch.dtype))

    # NOTE(review): unlike the single-fold solvers there is no guard for
    # L <= 0 here; a zero entry would give an infinite step - confirm that
    # callers never pass an all-zero Gram matrix.
    step = 1.0 / L.reshape(n_folds, 1, 1)
    # Convert alphas_desc to numpy using .get() if it's a CuPy array
    if hasattr(alphas_desc, 'get'):
        alphas_cpu = alphas_desc.get().astype(np.float64)
    else:
        alphas_cpu = np.asarray(alphas_desc, dtype=np.float64)
    alpha_gpu = cp.asarray(alphas_cpu, dtype=XtX_batch.dtype).reshape(1, 1, n_alphas)
    # Broadcasts to (n_folds, 1, n_alphas): per-fold step times per-alpha penalty.
    thresholds = alpha_gpu * step

    Xty_expanded = Xty_batch.reshape(n_folds, n_features, 1)
    n_vec_expanded = n_vec.reshape(n_folds, 1, 1)
    stopping_name = str(stopping).lower()
    check_every = max(1, int(check_every))

    # Boolean mask of (fold, alpha) problems still iterating, plus a host-side
    # counter kept in sync so the loop guard needs no device round trip.
    active_gpu = cp.ones((n_folds, n_alphas), dtype=cp.bool_)
    active_count = int(n_folds * n_alphas)

    # Note: Fused kernels disabled for multi-fold solver due to dtype complexity
    # The single-fold Lasso._fit_gpu uses fused kernels
    # NOTE(review): `use_fused` and `fused` are never read in this function.
    use_fused = False
    fused = None

    for iteration in range(int(max_iter)):
        if active_count == 0:
            break

        active_expanded = active_gpu[:, cp.newaxis, :]

        coef_old = coefs.copy()
        # Batched gradient of the smooth term at the extrapolation points.
        grad = (cp.matmul(XtX_batch, yk) - Xty_expanded) / n_vec_expanded

        # Proximal step: soft thresholding
        yk_step = yk - step * grad
        coef_candidate = cp.sign(yk_step) * cp.maximum(cp.abs(yk_step) - thresholds, 0.0)
        # Only active problems accept the new iterate; finished ones are frozen.
        coefs = cp.where(active_expanded, coef_candidate, coefs)

        t_old = tk
        t_new = (1.0 + cp.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
        beta = (t_old - 1.0) / t_new
        y_candidate = coefs + beta[:, cp.newaxis, :] * (coefs - coef_old)
        yk = cp.where(active_expanded, y_candidate, yk)
        tk = cp.where(active_gpu, t_new, tk)

        active_ratio = float(active_count) / float(max(1, n_folds * n_alphas))
        check_every_eff = _adaptive_gpu_check_every(
            base_check_every=check_every,
            iteration=iteration,
            max_iter=int(max_iter),
            active_ratio=active_ratio,
        )
        should_check = ((iteration + 1) % check_every_eff == 0) or (iteration + 1 == int(max_iter))
        if not should_check:
            continue

        if stopping_name == "kkt":
            # Largest excess of |gradient| over alpha across features (axis 1).
            grad_sse = (cp.matmul(XtX_batch, coefs) - Xty_expanded) / n_vec_expanded
            violation = cp.max(cp.maximum(cp.abs(grad_sse) - alpha_gpu, 0.0), axis=1)
            converged_local_gpu = violation < float(tol)
        else:
            delta = cp.sum(cp.abs(coefs - coef_old), axis=1)
            converged_local_gpu = delta < float(tol)

        # Restrict to problems that were still active before this check.
        newly_done_gpu = active_gpu & converged_local_gpu
        done_count = int(cp.count_nonzero(newly_done_gpu).item())
        if done_count == 0:
            continue

        n_iters_gpu[newly_done_gpu] = int(iteration) + 1
        # Pin the extrapolation point of finished problems to their solution.
        yk = cp.where(newly_done_gpu[:, cp.newaxis, :], coefs, yk)
        active_gpu = active_gpu & (~converged_local_gpu)
        active_count -= done_count

    # Problems still active exhausted the iteration budget.
    n_iters_gpu[active_gpu] = int(max_iter)

    return cp.transpose(coefs, (0, 2, 1)), cp.asnumpy(n_iters_gpu)
3413
+
3414
+
3415
def _solve_lasso_path_cpu_from_gram(
    XtX: np.ndarray,
    Xty: np.ndarray,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    cpu_solver: str,
    lipschitz_L: Optional[float] = None,
    cd_kkt_check_every: int = 1,
) -> tuple[np.ndarray, np.ndarray]:
    """Solve a descending-alpha Lasso path on CPU using one precomputed Gram matrix.

    Dispatch order:

    1. ``cpu_solver == "fista"``: delegate to
       ``_solve_lasso_path_cpu_fista_batched_from_gram`` (``check_every=2``).
    2. ``cpu_solver == "coordinate_descent"`` and the Numba kernel is available
       and not disabled: delegate to ``_solve_lasso_path_cpu_cd_numba``.  If
       that call raises, the kernel is disabled for the rest of the process
       (module flag ``_NUMBA_CD_DISABLED``), a ``RuntimeWarning`` is emitted,
       and execution falls through to step 3.
    3. Otherwise: NumPy cyclic coordinate descent with an active set,
       incremental gradient updates and warm starts along the path.

    The NumPy path thresholds with ``alpha * max(1, n_samples)`` against the
    unnormalised Gram statistics and measures KKT violation as
    ``max(|grad| / max(1, n_samples) - alpha, 0)``.  Assuming
    ``_soft_threshold_scalar`` is the usual soft-threshold operator, this is
    the objective ``1/(2n) * ||y - X b||^2 + alpha * ||b||_1`` with
    ``XtX = X.T @ X`` and ``Xty = X.T @ y``.

    Parameters
    ----------
    XtX : np.ndarray
        Gram matrix, indexed as ``(n_features, n_features)``.
    Xty : np.ndarray
        Feature/response cross-products, one entry per feature.
    n_samples : int
        Number of rows behind the Gram statistics; values below 1 are treated
        as 1 when scaling the penalty.
    alphas_desc : np.ndarray
        Penalty values, visited in the given order with warm starts (callers
        in this file pass them sorted in descending order).
    max_iter : int
        Maximum number of coordinate sweeps per alpha.
    tol : float
        Convergence tolerance for the chosen stopping rule.
    stopping : str
        ``"kkt"`` stops on the KKT violation; any other value stops when the
        L1 change of a sweep is below ``tol`` and no inactive coordinate
        violates KKT.
    cpu_solver : str
        ``"fista"`` or ``"coordinate_descent"`` (case-insensitive); any other
        value uses the NumPy coordinate-descent path.
    lipschitz_L : float, optional
        Forwarded to the FISTA solver only.
    cd_kkt_check_every : int
        Run a full KKT scan every this many sweeps (clamped to at least 1).
        A scan is also forced when a sweep barely moves and on the last sweep.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        For the NumPy path: ``(coefs_path, n_iters)`` with ``coefs_path`` of
        shape ``(n_alphas, n_features)`` (float64) and ``n_iters`` of shape
        ``(n_alphas,)`` (int32).  Delegated solvers return their own result.
    """
    global _NUMBA_CD_DISABLED

    stopping_name = str(stopping).lower()
    solver_name = str(cpu_solver).lower()

    if solver_name == "fista":
        return _solve_lasso_path_cpu_fista_batched_from_gram(
            XtX,
            Xty,
            n_samples=n_samples,
            alphas_desc=alphas_desc,
            max_iter=max_iter,
            tol=tol,
            stopping=stopping,
            lipschitz_L=lipschitz_L,
            check_every=2,
        )

    if (
        _NUMBA_AVAILABLE
        and (not _NUMBA_CD_DISABLED)
        and solver_name == "coordinate_descent"
    ):
        try:
            return _solve_lasso_path_cpu_cd_numba(
                XtX,
                Xty,
                n_samples=n_samples,
                alphas_desc=alphas_desc,
                max_iter=max_iter,
                tol=tol,
                stopping=stopping,
                cd_kkt_check_every=cd_kkt_check_every,
            )
        except Exception as exc:
            # Best-effort acceleration: keep going with the NumPy solver, but
            # surface the failure because it disables the kernel process-wide.
            import warnings

            _NUMBA_CD_DISABLED = True
            warnings.warn(
                "Numba coordinate-descent Lasso kernel failed "
                f"({type(exc).__name__}: {exc}); falling back to the NumPy "
                "implementation for the rest of this process.",
                RuntimeWarning,
                stacklevel=2,
            )

    # Buffers are only needed by the NumPy path, so allocate them after the
    # delegating early returns above.
    n_features = int(XtX.shape[0])
    alphas_arr = np.asarray(alphas_desc, dtype=np.float64)
    n_alphas = int(alphas_arr.shape[0])

    coefs_path = np.zeros((n_alphas, n_features), dtype=np.float64)
    n_iters = np.zeros(n_alphas, dtype=np.int32)
    coef = np.zeros(n_features, dtype=np.float64)

    # Loop-invariant conversions.
    tol_f = float(tol)
    max_iter_i = int(max_iter)
    n_scale = float(max(1, n_samples))
    kkt_every = max(1, int(cd_kkt_check_every))

    # Coordinate descent with incremental gradient updates.
    X_sq_norms = np.diag(XtX).astype(np.float64, copy=False)
    abs_Xty = np.abs(Xty)
    grad = XtX @ coef - Xty
    alpha_scaled_desc = alphas_arr * n_scale
    active_mask = np.zeros((n_features,), dtype=bool)

    for alpha_idx, alpha in enumerate(alphas_arr):
        alpha_f = float(alpha)
        alpha_scaled = float(alpha_scaled_desc[alpha_idx])
        prev_alpha_scaled = (
            float(alpha_scaled_desc[alpha_idx - 1]) if alpha_idx > 0 else alpha_scaled
        )

        # Strong rule screening: expand active set before cyclic updates.
        strong_thresh = max(0.0, 2.0 * alpha_scaled - prev_alpha_scaled)
        active_mask |= abs_Xty >= strong_thresh
        if not bool(np.any(active_mask)):
            # Always keep at least the most correlated feature active.
            active_mask[int(np.argmax(abs_Xty))] = True

        converged = False

        for iteration in range(max_iter_i):
            coef_delta_l1 = 0.0

            for j in np.flatnonzero(active_mask):
                denom = float(X_sq_norms[j])
                old_val = float(coef[j])

                if denom > 1e-10:
                    rho_j = -float(grad[j]) + denom * old_val
                    new_val = _soft_threshold_scalar(rho_j, alpha_scaled) / denom
                else:
                    # (Near-)zero column: its coefficient is pinned at zero.
                    new_val = 0.0

                delta = new_val - old_val
                if abs(delta) > 0.0:
                    coef[j] = new_val
                    # Keep grad == XtX @ coef - Xty without a full matvec.
                    grad += XtX[:, j] * delta
                    coef_delta_l1 += abs(delta)

            # glmnet-style optimization can skip full inactive KKT scans on every pass,
            # then force a check when updates become small.
            should_kkt_scan = (
                ((iteration + 1) % kkt_every == 0)
                or (coef_delta_l1 < tol_f)
                or (iteration + 1 == max_iter_i)
            )
            violation = float("inf")
            inactive_violation_idx = np.empty((0,), dtype=np.int64)

            if should_kkt_scan:
                violation_vec = np.maximum(np.abs(grad / n_scale) - alpha_f, 0.0)
                inactive_violation_idx = np.where(
                    (violation_vec > tol_f) & (~active_mask)
                )[0]
                if inactive_violation_idx.size > 0:
                    # Screened-out coordinates that violate KKT join the active set.
                    active_mask[inactive_violation_idx] = True
                violation = float(np.max(violation_vec))

            if stopping_name == "kkt":
                if should_kkt_scan and violation < tol_f:
                    n_iters[alpha_idx] = iteration + 1
                    converged = True
                    break
            else:
                if coef_delta_l1 < tol_f and inactive_violation_idx.size == 0:
                    n_iters[alpha_idx] = iteration + 1
                    converged = True
                    break

        if not converged:
            n_iters[alpha_idx] = max_iter_i

        coefs_path[alpha_idx, :] = coef
        # Warm start: coordinates that are non-zero stay active for the next alpha.
        active_mask |= np.abs(coef) > 0.0

    return coefs_path, n_iters
3551
+
3552
+
3553
def _solve_lasso_path_gpu_from_gram(
    XtX,
    Xty,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Run the GPU Lasso path solver on precomputed Gram statistics.

    Thin entry point that forwards every argument unchanged to
    ``_solve_lasso_path_gpu_fista_batched_from_gram`` and returns its result
    as-is (callers in this file unpack it as ``(coefs, n_iters)``).
    """
    solver_options = {
        "n_samples": n_samples,
        "alphas_desc": alphas_desc,
        "max_iter": max_iter,
        "tol": tol,
        "stopping": stopping,
        "lipschitz_L": lipschitz_L,
        "check_every": check_every,
    }
    return _solve_lasso_path_gpu_fista_batched_from_gram(XtX, Xty, **solver_options)
3577
+
3578
+
3579
def _batch_mse_numpy(
    X_val: np.ndarray,
    y_val: np.ndarray,
    coefs_path: np.ndarray,
    intercepts_path: np.ndarray,
    sample_weight_val: Optional[np.ndarray],
) -> np.ndarray:
    """Return the validation MSE of every model on a coefficient path (NumPy).

    ``coefs_path`` holds one coefficient vector per row and
    ``intercepts_path`` one intercept per row.  When ``sample_weight_val`` is
    given and its sum is positive, the squared errors are averaged with those
    weights; otherwise the plain mean over validation rows is returned.
    """
    fitted = X_val @ coefs_path.T + intercepts_path.reshape(1, -1)
    residual_sq = (y_val.reshape(-1, 1) - fitted) ** 2

    weight_total = None
    if sample_weight_val is not None:
        weight_total = float(np.sum(sample_weight_val))

    # No weights, or weights that sum to a non-positive value: unweighted mean.
    if weight_total is None or weight_total <= 0.0:
        return np.mean(residual_sq, axis=0)

    weighted_sq = sample_weight_val.reshape(-1, 1) * residual_sq
    return np.sum(weighted_sq, axis=0) / weight_total
3597
+
3598
+
3599
def _batch_mse(
    X_val,
    y_val,
    coefs_path,
    intercepts_path,
    backend,
    sample_weight_val,
) -> np.ndarray:
    """
    Evaluate the validation MSE of a whole coefficient path through a backend.

    Parameters
    ----------
    X_val : array-like
        Validation design matrix.
    y_val : array-like
        Validation response.
    coefs_path : array-like
        One coefficient vector per row, i.e. (n_alphas, n_features).
    intercepts_path : array-like
        One intercept per coefficient vector, i.e. (n_alphas,).
    backend : object
        Array backend; only its ``mean``, ``sum`` and ``to_numpy`` methods
        are used here.
    sample_weight_val : array-like or None
        Optional per-row weights.  They are ignored (plain mean) when their
        sum is not positive.

    Returns
    -------
    mse : ndarray
        Result of ``backend.to_numpy`` on the per-alpha MSE vector.
    """
    fitted = X_val @ coefs_path.T + intercepts_path.reshape(1, -1)
    residual_sq = (y_val.reshape(-1, 1) - fitted) ** 2

    weight_total = None
    use_weights = False
    if sample_weight_val is not None:
        weight_total = backend.sum(sample_weight_val)
        # Pulling the total to the host decides between weighted and plain mean.
        use_weights = not (float(backend.to_numpy(weight_total)) <= 0.0)

    if use_weights:
        weighted_sq = sample_weight_val.reshape(-1, 1) * residual_sq
        per_alpha = backend.sum(weighted_sq, axis=0) / weight_total
    else:
        per_alpha = backend.mean(residual_sq, axis=0)

    return backend.to_numpy(per_alpha)
3643
+
3644
+
3645
def _soft_threshold_torch(x, gamma):
    """Element-wise soft-thresholding of a Torch tensor.

    Returns ``sign(x) * max(|x| - gamma, 0)``; the zero used for the clamp is
    created with the dtype and device of ``x``.
    """
    import torch

    zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
    shrunk_magnitude = torch.maximum(torch.abs(x) - gamma, zero)
    return torch.sign(x) * shrunk_magnitude
3649
+
3650
+
3651
def _fit_lasso_single_alpha_fast(
    X,
    y,
    *,
    alpha: float,
    fit_intercept: bool,
    max_iter: int,
    tol: float,
    stopping: str,
    device: str,
    cpu_solver: str,
    cd_kkt_check_every: int = 1,
    sample_weight=None,
) -> Dict[str, object]:
    """Fast single-alpha Lasso fit using optimized Gram-based path solvers.

    Three execution paths are selected from ``device`` and the type of ``X``:

    * CUDA device and ``X`` is not a ``torch.Tensor``: CuPy arrays, solved by
      ``_solve_lasso_path_gpu_from_gram``.
    * CUDA device and ``X`` is a ``torch.Tensor``: an inline FISTA loop on the
      tensor's own device and dtype.
    * anything else: NumPy arrays, solved by
      ``_solve_lasso_path_cpu_from_gram``.

    In every path the rows are first scaled by ``sqrt(sample_weight)`` (when
    given), then optionally mean-centred, and the solver works on the Gram
    statistics ``X.T @ X`` and ``X.T @ y``.

    Parameters
    ----------
    X, y : array-like
        Design matrix (2-D) and response (flattened to 1-D).
    alpha : float
        L1 penalty strength.
    fit_intercept : bool
        Centre the data and report an intercept; otherwise intercept is 0.0.
    max_iter : int
        Iteration cap handed to / used by the solver.
    tol : float
        Convergence tolerance.
    stopping : str
        ``"kkt"`` or a coefficient-change criterion (any other value).
    device : str
        Compared, lower-cased, against ``Device.CUDA.value``.
    cpu_solver : str
        Solver name forwarded on the NumPy path.
    cd_kkt_check_every : int
        Forwarded on the NumPy path.
    sample_weight : array-like, optional
        Per-row weights; flattened to 1-D before use.

    Returns
    -------
    dict
        ``{"coef": float64 ndarray, "intercept": float, "n_iter": int,
        "n_samples": int, "n_features": int}``.
    """
    # NOTE(review): with sample_weight AND fit_intercept the means below are
    # plain means of the sqrt-weight-scaled rows (not weighted means of the raw
    # data).  _select_lasso_alpha_cv preprocesses its folds the same way, so
    # the convention is kept here; confirm it is the intended one.
    device_name = str(device).lower()
    wants_cuda = device_name == Device.CUDA.value
    use_intercept = bool(fit_intercept)
    alpha_vec = np.asarray([float(alpha)], dtype=np.float64)

    # Torch tensors get the native torch path; torch itself is optional.
    is_torch_gpu = False
    if wants_cuda:
        try:
            import torch

            is_torch_gpu = isinstance(X, torch.Tensor)
        except Exception:
            # Deliberate best-effort: a missing or broken torch install simply
            # means the CuPy path is used.
            is_torch_gpu = False

    if wants_cuda and not is_torch_gpu:
        # CuPy GPU path
        import cupy as cp

        X_arr = cp.asarray(X)
        y_arr = cp.asarray(y).reshape(-1)

        if sample_weight is not None:
            # Flatten so (n, 1) weights broadcast per row, as in the CV routine.
            sqrt_sw = cp.sqrt(cp.asarray(sample_weight).reshape(-1))
            X_arr = X_arr * sqrt_sw[:, cp.newaxis]
            y_arr = y_arr * sqrt_sw

        if use_intercept:
            X_mean = cp.mean(X_arr, axis=0)
            y_mean = cp.mean(y_arr)
            X_centered = X_arr - X_mean
            y_centered = y_arr - y_mean
        else:
            X_centered = X_arr
            y_centered = y_arr

        XtX = X_centered.T @ X_centered
        Xty = X_centered.T @ y_centered

        coefs_desc, n_iters = _solve_lasso_path_gpu_from_gram(
            XtX,
            Xty,
            n_samples=int(X_arr.shape[0]),
            alphas_desc=alpha_vec,
            max_iter=int(max_iter),
            tol=float(tol),
            stopping=str(stopping),
            lipschitz_L=None,
            check_every=8,
        )

        coef_gpu = coefs_desc[0]
        if use_intercept:
            intercept = float(cp.asnumpy(y_mean - X_mean @ coef_gpu))
        else:
            intercept = 0.0

        return {
            "coef": np.asarray(cp.asnumpy(coef_gpu), dtype=np.float64),
            "intercept": float(intercept),
            "n_iter": int(n_iters[0]),
            "n_samples": int(X_arr.shape[0]),
            "n_features": int(X_arr.shape[1]),
        }

    if is_torch_gpu:
        # Torch GPU path - use FISTA solver directly on GPU tensors
        import torch

        # Results leave this function as NumPy/float, so no autograd graph is
        # needed for the iterations.
        with torch.no_grad():
            X_arr = X
            # as_tensor also aligns dtype/device of tensors that differ from X.
            y_arr = torch.as_tensor(
                y, dtype=X_arr.dtype, device=X_arr.device
            ).reshape(-1)

            if sample_weight is not None:
                sw = torch.as_tensor(
                    sample_weight, dtype=X_arr.dtype, device=X_arr.device
                ).reshape(-1)
                sqrt_sw = torch.sqrt(sw)
                X_arr = X_arr * sqrt_sw[:, None]
                y_arr = y_arr * sqrt_sw

            if use_intercept:
                X_mean = torch.mean(X_arr, dim=0)
                y_mean = torch.mean(y_arr)
                X_centered = X_arr - X_mean
                y_centered = y_arr - y_mean
            else:
                X_centered = X_arr
                y_centered = y_arr

            n_samples = int(X_arr.shape[0])
            n_features = int(X_arr.shape[1])

            # Precompute Gram matrix and X'y for FISTA gradient
            XtX = X_centered.T @ X_centered
            Xty = X_centered.T @ y_centered

            # Compute Lipschitz constant L = max eigenvalue of XtX / n
            try:
                eigvals = torch.linalg.eigvalsh(XtX)
                L = eigvals[-1] / n_samples
            except Exception:
                # Squared Frobenius norm / n is used when the eigensolver fails.
                L = torch.sum(X_centered ** 2) / n_samples
            L = torch.clamp(L, min=1e-10)

            step = 1.0 / L
            thresh = float(alpha) * step

            stopping_name = str(stopping).lower()
            alpha_f = float(alpha)
            tol_f = float(tol)
            zero = torch.tensor(0.0, dtype=X_arr.dtype, device=X_arr.device)

            # FISTA initialization
            coef = torch.zeros(n_features, dtype=X_arr.dtype, device=X_arr.device)
            z = coef.clone()
            t = torch.tensor(1.0, dtype=X_arr.dtype, device=X_arr.device)

            # Tracked explicitly so max_iter == 0 reports 0 instead of
            # reading an unbound loop variable.
            n_iter_done = 0

            # FISTA iterations
            for iteration in range(int(max_iter)):
                coef_old = coef.clone()

                # Gradient step at z
                grad = (XtX @ z - Xty) / n_samples
                coef = _soft_threshold_torch(z - step * grad, thresh)

                # Momentum update
                t_new = (1.0 + torch.sqrt(1.0 + 4.0 * t ** 2)) / 2.0
                z = coef + ((t - 1.0) / t_new) * (coef - coef_old)
                t = t_new

                n_iter_done = iteration + 1

                # Convergence check
                if stopping_name == "kkt":
                    grad_sse = (XtX @ coef - Xty) / n_samples
                    violation = torch.max(
                        torch.maximum(torch.abs(grad_sse) - alpha_f, zero)
                    )
                    if violation < tol_f:
                        break
                else:
                    if torch.sum(torch.abs(coef - coef_old)) < tol_f:
                        break

            # Build coefficients
            if use_intercept:
                intercept = float((y_mean - X_mean @ coef).item())
            else:
                intercept = 0.0

            coef_np = np.asarray(coef.detach().cpu().numpy(), dtype=np.float64)

        return {
            "coef": coef_np,
            "intercept": float(intercept),
            "n_iter": int(n_iter_done),
            "n_samples": n_samples,
            "n_features": n_features,
        }

    # NumPy CPU path
    X_arr = np.asarray(X)
    y_arr = np.asarray(y).reshape(-1)

    if sample_weight is not None:
        # Flatten so (n, 1) weights broadcast per row, as in the CV routine.
        sqrt_sw = np.sqrt(np.asarray(sample_weight).reshape(-1))
        X_arr = X_arr * sqrt_sw[:, np.newaxis]
        y_arr = y_arr * sqrt_sw

    if use_intercept:
        X_mean = np.mean(X_arr, axis=0)
        y_mean = float(np.mean(y_arr))
        X_centered = X_arr - X_mean
        y_centered = y_arr - y_mean
    else:
        X_centered = X_arr
        y_centered = y_arr

    XtX = X_centered.T @ X_centered
    Xty = X_centered.T @ y_centered

    coefs_desc, n_iters = _solve_lasso_path_cpu_from_gram(
        XtX,
        Xty,
        n_samples=int(X_arr.shape[0]),
        alphas_desc=alpha_vec,
        max_iter=int(max_iter),
        tol=float(tol),
        stopping=str(stopping),
        cpu_solver=str(cpu_solver),
        lipschitz_L=None,
        cd_kkt_check_every=int(cd_kkt_check_every),
    )

    coef = np.asarray(coefs_desc[0], dtype=np.float64)
    if use_intercept:
        intercept = float(y_mean - X_mean @ coef)
    else:
        intercept = 0.0

    return {
        "coef": coef,
        "intercept": float(intercept),
        "n_iter": int(n_iters[0]),
        "n_samples": int(X_arr.shape[0]),
        "n_features": int(X_arr.shape[1]),
    }
3871
+
3872
+
3873
+ def _select_lasso_alpha_cv(
3874
+ X,
3875
+ y,
3876
+ *,
3877
+ alphas=None,
3878
+ n_alphas: int = 12,
3879
+ alpha_min_ratio: float = 1e-3,
3880
+ cv_folds: int = 5,
3881
+ cv_splits=None,
3882
+ random_state: Optional[int] = None,
3883
+ sample_weight=None,
3884
+ fit_intercept: bool = False,
3885
+ device: Union[str, Device] = Device.CPU,
3886
+ max_iter: int = 3000,
3887
+ tol: float = 1e-4,
3888
+ cpu_solver: str = "coordinate_descent",
3889
+ method: str = "standard",
3890
+ cd_kkt_check_every: Optional[int] = None,
3891
+ gpu_cv_mixed_precision: bool = True,
3892
+ return_details: bool = False,
3893
+ cache_key: Optional[Tuple[Any, ...]] = None,
3894
+ ):
3895
+ """
3896
+ Select alpha via K-fold CV using statgpu's own Lasso implementation.
3897
+
3898
+ Notes
3899
+ -----
3900
+ - Does not depend on sklearn.
3901
+ - Supports GPU path by setting ``device='cuda'``.
3902
+ """
3903
+ device_name = str(device).lower()
3904
+ use_gpu = device_name == Device.CUDA.value
3905
+ gpu_requested = use_gpu
3906
+
3907
+ gpu_input_cupy = False
3908
+ gpu_input_torch = False
3909
+ if use_gpu:
3910
+ # Check if inputs are already on GPU (CuPy or Torch)
3911
+ try:
3912
+ import cupy as cp
3913
+ gpu_input_cupy = isinstance(X, cp.ndarray) and isinstance(y, cp.ndarray)
3914
+ if sample_weight is not None and not isinstance(sample_weight, cp.ndarray):
3915
+ gpu_input_cupy = False
3916
+ except Exception:
3917
+ pass
3918
+
3919
+ # Also check for torch tensors
3920
+ if not gpu_input_cupy:
3921
+ try:
3922
+ import torch
3923
+ gpu_input_torch = isinstance(X, torch.Tensor) and isinstance(y, torch.Tensor)
3924
+ if sample_weight is not None and not isinstance(sample_weight, torch.Tensor):
3925
+ gpu_input_torch = False
3926
+ except Exception:
3927
+ pass
3928
+
3929
+ X_np = None
3930
+ y_np = None
3931
+ sample_weight_np = None
3932
+
3933
+ if gpu_input_cupy or gpu_input_torch:
3934
+ # GPU inputs - get backend for validation
3935
+ backend = get_backend(backend='auto', device='cuda')
3936
+ if len(tuple(X.shape)) != 2:
3937
+ raise ValueError("X must be a 2D array")
3938
+ n_samples = int(X.shape[0])
3939
+ y_check = backend.asarray(y).reshape(-1)
3940
+ if int(y_check.shape[0]) != n_samples:
3941
+ raise ValueError("y must have the same number of rows as X")
3942
+ if sample_weight is not None:
3943
+ sw_check = backend.asarray(sample_weight).reshape(-1)
3944
+ if int(sw_check.shape[0]) != n_samples:
3945
+ raise ValueError("sample_weight must have the same number of rows as X")
3946
+ else:
3947
+ X_np = np.asarray(X, dtype=np.float64)
3948
+ y_np = np.asarray(y, dtype=np.float64).reshape(-1)
3949
+ if sample_weight is not None:
3950
+ sample_weight_np = np.asarray(sample_weight, dtype=np.float64).reshape(-1)
3951
+ if X_np.ndim != 2:
3952
+ raise ValueError("X must be a 2D array")
3953
+ if y_np.shape[0] != X_np.shape[0]:
3954
+ raise ValueError("y must have the same number of rows as X")
3955
+ if sample_weight_np is not None and sample_weight_np.shape[0] != X_np.shape[0]:
3956
+ raise ValueError("sample_weight must have the same number of rows as X")
3957
+ n_samples = int(X_np.shape[0])
3958
+
3959
+ cv_method = _normalize_lassocv_method(method)
3960
+ requested_cd_kkt_check_every = _normalize_cd_kkt_check_every(cd_kkt_check_every)
3961
+
3962
+ if alphas is None:
3963
+ if gpu_input_cupy or gpu_input_torch:
3964
+ # Get backend based on input type
3965
+ if gpu_input_torch:
3966
+ backend = get_backend(backend='torch', device='cuda')
3967
+ else:
3968
+ backend = get_backend(backend='cupy', device='cuda')
3969
+ alpha_grid = _default_lasso_alpha_grid_backend(
3970
+ X,
3971
+ y,
3972
+ backend,
3973
+ n_alphas=n_alphas,
3974
+ alpha_min_ratio=alpha_min_ratio,
3975
+ )
3976
+ else:
3977
+ alpha_grid = _default_lasso_alpha_grid(
3978
+ X_np,
3979
+ y_np,
3980
+ n_alphas=n_alphas,
3981
+ alpha_min_ratio=alpha_min_ratio,
3982
+ )
3983
+ else:
3984
+ alpha_grid = np.asarray(alphas, dtype=np.float64).reshape(-1)
3985
+ alpha_grid = alpha_grid[np.isfinite(alpha_grid)]
3986
+ alpha_grid = alpha_grid[alpha_grid > 0.0]
3987
+ if alpha_grid.size == 0:
3988
+ if gpu_input_cupy or gpu_input_torch:
3989
+ # Get backend based on input type
3990
+ if gpu_input_torch:
3991
+ backend = get_backend(backend='torch', device='cuda')
3992
+ else:
3993
+ backend = get_backend(backend='cupy', device='cuda')
3994
+ alpha_grid = _default_lasso_alpha_grid_backend(
3995
+ X,
3996
+ y,
3997
+ backend,
3998
+ n_alphas=n_alphas,
3999
+ alpha_min_ratio=alpha_min_ratio,
4000
+ )
4001
+ else:
4002
+ alpha_grid = _default_lasso_alpha_grid(
4003
+ X_np,
4004
+ y_np,
4005
+ n_alphas=n_alphas,
4006
+ alpha_min_ratio=alpha_min_ratio,
4007
+ )
4008
+
4009
+ user_folds = _normalize_cv_splits(cv_splits, n_samples=n_samples)
4010
+ effective_n_folds = int(len(user_folds)) if user_folds is not None else int(cv_folds)
4011
+
4012
+ if int(n_samples) < 4 or int(alpha_grid.size) == 1 or int(effective_n_folds) < 2:
4013
+ alpha0 = float(alpha_grid[0])
4014
+ if not return_details:
4015
+ return alpha0
4016
+ return {
4017
+ "alpha": alpha0,
4018
+ "alphas": alpha_grid.astype(np.float64, copy=False),
4019
+ "mse_path": np.full((int(alpha_grid.size), 1), np.nan, dtype=np.float64),
4020
+ "mean_mse": np.full(int(alpha_grid.size), np.nan, dtype=np.float64),
4021
+ }
4022
+
4023
+ if user_folds is not None:
4024
+ folds = user_folds
4025
+ else:
4026
+ folds = _kfold_indices(
4027
+ n_samples=int(n_samples),
4028
+ n_splits=int(cv_folds),
4029
+ random_state=random_state,
4030
+ )
4031
+
4032
+ folds_are_complements = _folds_are_complements(folds, n_samples=int(n_samples))
4033
+
4034
+ alpha_grid = alpha_grid.astype(np.float64, copy=False)
4035
+ n_alpha = int(alpha_grid.size)
4036
+ n_folds = int(len(folds))
4037
+
4038
+ cache_key_eff = cache_key
4039
+ if cache_key_eff is None and _LASSO_CV_ALPHA_CACHE_MAXSIZE > 0:
4040
+ cache_key_eff = _make_lasso_cv_auto_cache_key(
4041
+ X=X,
4042
+ y=y,
4043
+ sample_weight=sample_weight,
4044
+ alpha_grid=alpha_grid,
4045
+ folds=folds,
4046
+ fit_intercept=bool(fit_intercept),
4047
+ use_gpu=bool(use_gpu),
4048
+ max_iter=int(max_iter),
4049
+ tol=float(tol),
4050
+ cpu_solver=str(cpu_solver),
4051
+ cv_method=str(cv_method),
4052
+ cd_kkt_check_every=requested_cd_kkt_check_every,
4053
+ gpu_cv_mixed_precision=bool(gpu_cv_mixed_precision),
4054
+ )
4055
+
4056
+ cached_details = _lasso_cv_cache_get(cache_key_eff)
4057
+ if cached_details is not None:
4058
+ if return_details:
4059
+ return cached_details
4060
+ return float(cached_details["alpha"])
4061
+
4062
+ # Evaluate alpha path in descending order for warm-start efficiency.
4063
+ alpha_order_desc = np.argsort(-alpha_grid)
4064
+ alpha_desc = alpha_grid[alpha_order_desc]
4065
+
4066
+ mse_path = np.full((n_alpha, n_folds), np.nan, dtype=np.float64)
4067
+
4068
+ best_alpha = float(alpha_grid[0])
4069
+ best_mse = float("inf")
4070
+
4071
+ if use_gpu:
4072
+ try:
4073
+ # Get backend based on input type - prefer Torch backend for Torch tensors
4074
+ if gpu_input_torch:
4075
+ backend = get_backend(backend='torch', device='cuda')
4076
+ elif gpu_input_cupy:
4077
+ backend = get_backend(backend='cupy', device='cuda')
4078
+ else:
4079
+ backend = get_backend(backend='auto', device='cuda')
4080
+ xp = backend.xp
4081
+
4082
+ cv_dtype = backend.float32 if bool(gpu_cv_mixed_precision) else backend.float64
4083
+
4084
+ # Convert inputs to backend arrays
4085
+ if gpu_input_cupy or gpu_input_torch:
4086
+ # Already on GPU (CuPy or Torch)
4087
+ X_full = backend.asarray(X, dtype=cv_dtype)
4088
+ y_full = backend.asarray(y, dtype=cv_dtype).reshape(-1)
4089
+ if sample_weight is not None:
4090
+ sw_full = backend.asarray(sample_weight, dtype=cv_dtype).reshape(-1)
4091
+ else:
4092
+ sw_full = None
4093
+ else:
4094
+ # Convert from numpy
4095
+ X_full = backend.asarray(X_np, dtype=cv_dtype)
4096
+ y_full = backend.asarray(y_np, dtype=cv_dtype)
4097
+ if sample_weight_np is not None:
4098
+ sw_full = backend.asarray(sample_weight_np, dtype=cv_dtype)
4099
+ else:
4100
+ sw_full = None
4101
+
4102
+ XtX_folds = []
4103
+ Xty_folds = []
4104
+ n_train_folds = []
4105
+ X_mean_folds = []
4106
+ y_mean_folds = []
4107
+ fold_eval_payload = []
4108
+
4109
+ fast_fold_stats = (sw_full is None) and bool(folds_are_complements)
4110
+ if fast_fold_stats:
4111
+ n_total = int(X_full.shape[0])
4112
+ XtX_full = X_full.T @ X_full
4113
+ Xty_full = X_full.T @ y_full
4114
+ if bool(fit_intercept):
4115
+ X_sum_full = backend.sum(X_full, axis=0)
4116
+ y_sum_full = backend.sum(y_full)
4117
+ else:
4118
+ X_sum_full = None
4119
+ y_sum_full = None
4120
+
4121
+ for fold_idx, (train_idx, val_idx) in enumerate(folds):
4122
+ train_idx_gpu = backend.asarray(train_idx)
4123
+ val_idx_gpu = backend.asarray(val_idx)
4124
+
4125
+ X_val = X_full[val_idx_gpu]
4126
+ y_val = y_full[val_idx_gpu]
4127
+ sw_val = None if sw_full is None else sw_full[val_idx_gpu]
4128
+
4129
+ if fast_fold_stats:
4130
+ n_val = int(val_idx_gpu.shape[0])
4131
+ n_train = int(n_total - n_val)
4132
+
4133
+ XtX_val = X_val.T @ X_val
4134
+ Xty_val = X_val.T @ y_val
4135
+ XtX_raw = XtX_full - XtX_val
4136
+ Xty_raw = Xty_full - Xty_val
4137
+
4138
+ if bool(fit_intercept):
4139
+ X_sum_val = backend.sum(X_val, axis=0)
4140
+ y_sum_val = backend.sum(y_val)
4141
+ X_sum_train = X_sum_full - X_sum_val
4142
+ y_sum_train = y_sum_full - y_sum_val
4143
+
4144
+ inv_n = backend.asarray(1.0 / float(max(1, n_train)), dtype=X_full.dtype)
4145
+ X_mean = X_sum_train * inv_n
4146
+ y_mean = y_sum_train * inv_n
4147
+ XtX = XtX_raw - backend.outer(X_sum_train, X_sum_train) * inv_n
4148
+ Xty = Xty_raw - X_sum_train * y_mean
4149
+ else:
4150
+ X_mean = backend.zeros((X_full.shape[1],), dtype=X_full.dtype)
4151
+ y_mean = backend.array(0.0, dtype=X_full.dtype)
4152
+ XtX = XtX_raw
4153
+ Xty = Xty_raw
4154
+ else:
4155
+ X_train = X_full[train_idx_gpu]
4156
+ y_train = y_full[train_idx_gpu]
4157
+ sw_train = None if sw_full is None else sw_full[train_idx_gpu]
4158
+
4159
+ if sw_train is not None:
4160
+ sqrt_sw = backend.sqrt(sw_train)
4161
+ X_train = X_train * sqrt_sw[:, None]
4162
+ y_train = y_train * sqrt_sw
4163
+
4164
+ if bool(fit_intercept):
4165
+ X_mean = backend.mean(X_train, axis=0)
4166
+ y_mean = backend.mean(y_train)
4167
+ X_centered = X_train - X_mean
4168
+ y_centered = y_train - y_mean
4169
+ else:
4170
+ X_mean = backend.zeros((X_train.shape[1],), dtype=X_train.dtype)
4171
+ y_mean = backend.array(0.0, dtype=X_train.dtype)
4172
+ X_centered = X_train
4173
+ y_centered = y_train
4174
+
4175
+ XtX = X_centered.T @ X_centered
4176
+ Xty = X_centered.T @ y_centered
4177
+ n_train = int(X_train.shape[0])
4178
+
4179
+ XtX_folds.append(XtX)
4180
+ Xty_folds.append(Xty)
4181
+ n_train_folds.append(int(n_train))
4182
+ X_mean_folds.append(X_mean)
4183
+ y_mean_folds.append(y_mean)
4184
+ fold_eval_payload.append((X_val, y_val, sw_val))
4185
+
4186
+ XtX_batch = backend.stack(XtX_folds, axis=0)
4187
+ Xty_batch = backend.stack(Xty_folds, axis=0)
4188
+
4189
+ # Use native Torch FISTA solver for Torch backend
4190
+ if hasattr(xp, '__name__') and 'torch' in xp.__name__.lower():
4191
+ import torch
4192
+ n_samples_vec_torch = torch.tensor(np.asarray(n_train_folds, dtype=np.int32), device=XtX_batch.device, dtype=XtX_batch.dtype)
4193
+
4194
+ coefs_batch_desc, _ = _solve_lasso_path_gpu_fista_multi_fold_from_gram_torch(
4195
+ XtX_batch,
4196
+ Xty_batch,
4197
+ n_samples_vec=n_samples_vec_torch,
4198
+ alphas_desc=alpha_desc,
4199
+ max_iter=int(max_iter),
4200
+ tol=float(tol),
4201
+ stopping="coef_delta",
4202
+ lipschitz_L=None,
4203
+ check_every=8,
4204
+ )
4205
+
4206
+ # Convert results back to numpy for evaluation
4207
+ for fold_idx in range(int(len(folds))):
4208
+ coefs_desc_np = coefs_batch_desc[fold_idx] # already numpy from the solver
4209
+
4210
+ if bool(fit_intercept):
4211
+ y_mean_val = float(y_mean_folds[fold_idx])
4212
+ X_mean_val = X_mean_folds[fold_idx]
4213
+ intercepts_desc = y_mean_val - X_mean_val @ coefs_desc_np.T
4214
+ intercepts_desc_gpu = backend.asarray(intercepts_desc)
4215
+ coefs_desc_gpu = backend.asarray(coefs_desc_np)
4216
+ else:
4217
+ intercepts_desc_gpu = backend.zeros((coefs_desc_np.shape[0],), dtype=coefs_desc_np.dtype)
4218
+ coefs_desc_gpu = backend.asarray(coefs_desc_np)
4219
+
4220
+ X_val, y_val, sw_val = fold_eval_payload[fold_idx]
4221
+ mse_desc = _batch_mse(X_val, y_val, coefs_desc_gpu, intercepts_desc_gpu, backend, sw_val)
4222
+
4223
+ mse_path[alpha_order_desc, fold_idx] = mse_desc
4224
+ else:
4225
+ # CuPy backend - use existing solver directly
4226
+ import cupy as cp
4227
+ n_samples_vec_cp = cp.asarray(np.asarray(n_train_folds, dtype=np.int32))
4228
+
4229
+ coefs_batch_desc, _ = _solve_lasso_path_gpu_fista_multi_fold_from_gram(
4230
+ XtX_batch,
4231
+ Xty_batch,
4232
+ n_samples_vec=n_samples_vec_cp,
4233
+ alphas_desc=alpha_desc,
4234
+ max_iter=int(max_iter),
4235
+ tol=float(tol),
4236
+ stopping="coef_delta",
4237
+ lipschitz_L=None,
4238
+ check_every=8,
4239
+ )
4240
+
4241
+ for fold_idx in range(int(len(folds))):
4242
+ coefs_desc = coefs_batch_desc[fold_idx]
4243
+
4244
+ if bool(fit_intercept):
4245
+ intercepts_desc = y_mean_folds[fold_idx] - X_mean_folds[fold_idx] @ coefs_desc.T
4246
+ else:
4247
+ intercepts_desc = backend.zeros((coefs_desc.shape[0],), dtype=coefs_desc.dtype)
4248
+
4249
+ X_val, y_val, sw_val = fold_eval_payload[fold_idx]
4250
+ mse_desc = _batch_mse(X_val, y_val, coefs_desc, intercepts_desc, backend, sw_val)
4251
+
4252
+ mse_path[alpha_order_desc, fold_idx] = mse_desc
4253
+
4254
+ except Exception as exc:
4255
+ raise RuntimeError(
4256
+ "GPU path failed in _select_lasso_alpha_cv with device='cuda'; "
4257
+ "CPU fallback is disabled for strict CUDA execution."
4258
+ ) from exc
4259
+
4260
+ if not use_gpu:
4261
+ if gpu_requested:
4262
+ raise RuntimeError(
4263
+ "device='cuda' requested but GPU path was not executed; "
4264
+ "CPU fallback is disabled for strict CUDA execution."
4265
+ )
4266
+ cpu_solver_name = str(cpu_solver).lower()
4267
+
4268
+ if cv_method == "glmnet":
4269
+ # glmnet-like CV profile: coordinate-descent path with periodic full KKT scans.
4270
+ cpu_solver_name = "coordinate_descent"
4271
+
4272
+ if requested_cd_kkt_check_every is None:
4273
+ cd_kkt_check_every_effective = 4 if cv_method == "glmnet" else 1
4274
+ else:
4275
+ cd_kkt_check_every_effective = int(requested_cd_kkt_check_every)
4276
+
4277
+ fast_fold_stats = (sample_weight_np is None) and bool(folds_are_complements)
4278
+ if fast_fold_stats:
4279
+ n_total = int(X_np.shape[0])
4280
+ XtX_full = X_np.T @ X_np
4281
+ Xty_full = X_np.T @ y_np
4282
+ if bool(fit_intercept):
4283
+ X_sum_full = np.sum(X_np, axis=0)
4284
+ y_sum_full = float(np.sum(y_np))
4285
+ else:
4286
+ X_sum_full = None
4287
+ y_sum_full = None
4288
+
4289
+ for fold_idx, (train_idx, val_idx) in enumerate(folds):
4290
+ X_val = X_np[val_idx]
4291
+ y_val = y_np[val_idx]
4292
+ sw_val = None if sample_weight_np is None else sample_weight_np[val_idx]
4293
+
4294
+ if fast_fold_stats:
4295
+ n_val = int(np.asarray(val_idx, dtype=np.int64).reshape(-1).size)
4296
+ n_train = int(n_total - n_val)
4297
+
4298
+ XtX_val = X_val.T @ X_val
4299
+ Xty_val = X_val.T @ y_val
4300
+ XtX_raw = XtX_full - XtX_val
4301
+ Xty_raw = Xty_full - Xty_val
4302
+
4303
+ if bool(fit_intercept):
4304
+ X_sum_val = np.sum(X_val, axis=0)
4305
+ y_sum_val = float(np.sum(y_val))
4306
+ X_sum_train = X_sum_full - X_sum_val
4307
+ y_sum_train = y_sum_full - y_sum_val
4308
+
4309
+ inv_n = 1.0 / float(max(1, n_train))
4310
+ X_mean = X_sum_train * inv_n
4311
+ y_mean = y_sum_train * inv_n
4312
+ XtX = XtX_raw - np.outer(X_sum_train, X_sum_train) * inv_n
4313
+ Xty = Xty_raw - X_sum_train * y_mean
4314
+ else:
4315
+ X_mean = np.zeros((X_np.shape[1],), dtype=np.float64)
4316
+ y_mean = 0.0
4317
+ XtX = XtX_raw
4318
+ Xty = Xty_raw
4319
+ else:
4320
+ X_train = X_np[train_idx]
4321
+ y_train = y_np[train_idx]
4322
+ sw_train = None if sample_weight_np is None else sample_weight_np[train_idx]
4323
+
4324
+ if sw_train is not None:
4325
+ sqrt_sw = np.sqrt(sw_train)
4326
+ X_train = X_train * sqrt_sw[:, np.newaxis]
4327
+ y_train = y_train * sqrt_sw
4328
+
4329
+ if bool(fit_intercept):
4330
+ X_mean = np.mean(X_train, axis=0)
4331
+ y_mean = float(np.mean(y_train))
4332
+ X_centered = X_train - X_mean
4333
+ y_centered = y_train - y_mean
4334
+ else:
4335
+ X_mean = np.zeros((X_train.shape[1],), dtype=np.float64)
4336
+ y_mean = 0.0
4337
+ X_centered = X_train
4338
+ y_centered = y_train
4339
+
4340
+ XtX = X_centered.T @ X_centered
4341
+ Xty = X_centered.T @ y_centered
4342
+ n_train = int(X_train.shape[0])
4343
+
4344
+ coefs_desc, _ = _solve_lasso_path_cpu_from_gram(
4345
+ XtX,
4346
+ Xty,
4347
+ n_samples=int(n_train),
4348
+ alphas_desc=alpha_desc,
4349
+ max_iter=int(max_iter),
4350
+ tol=float(tol),
4351
+ stopping="coef_delta",
4352
+ cpu_solver=cpu_solver_name,
4353
+ lipschitz_L=None,
4354
+ cd_kkt_check_every=cd_kkt_check_every_effective,
4355
+ )
4356
+
4357
+ if bool(fit_intercept):
4358
+ intercepts_desc = y_mean - X_mean @ coefs_desc.T
4359
+ else:
4360
+ intercepts_desc = np.zeros((coefs_desc.shape[0],), dtype=np.float64)
4361
+
4362
+ mse_desc = _batch_mse_numpy(
4363
+ X_val,
4364
+ y_val,
4365
+ coefs_desc,
4366
+ intercepts_desc,
4367
+ sw_val,
4368
+ )
4369
+
4370
+ mse_path[alpha_order_desc, fold_idx] = np.asarray(mse_desc, dtype=np.float64)
4371
+
4372
+ for alpha_idx, alpha in enumerate(alpha_grid):
4373
+ alpha_f = float(alpha)
4374
+ valid = np.isfinite(mse_path[alpha_idx])
4375
+ if not bool(np.any(valid)):
4376
+ continue
4377
+
4378
+ mean_mse = float(np.mean(mse_path[alpha_idx, valid]))
4379
+ if mean_mse < best_mse:
4380
+ best_mse = mean_mse
4381
+ best_alpha = alpha_f
4382
+
4383
+ mean_mse_vec = np.full(int(alpha_grid.size), np.nan, dtype=np.float64)
4384
+ for alpha_idx in range(int(alpha_grid.size)):
4385
+ valid = np.isfinite(mse_path[alpha_idx])
4386
+ if bool(np.any(valid)):
4387
+ mean_mse_vec[alpha_idx] = float(np.mean(mse_path[alpha_idx, valid]))
4388
+
4389
+ details = {
4390
+ "alpha": float(best_alpha),
4391
+ "alphas": alpha_grid.astype(np.float64, copy=False),
4392
+ "mse_path": mse_path,
4393
+ "mean_mse": mean_mse_vec,
4394
+ }
4395
+
4396
+ _lasso_cv_cache_put(cache_key_eff, details)
4397
+
4398
+ if return_details:
4399
+ return details
4400
+
4401
+ return float(details["alpha"])
4402
+
4403
+
4404
class LassoCV(CVEstimatorBase):
    """
    Cross-validated Lasso built on top of statgpu's own ``Lasso`` implementation.

    This class mirrors the common sklearn-style usage pattern while keeping
    backend/device behavior consistent with statgpu models.

    ``fit`` first selects the penalty strength with ``_select_lasso_alpha_cv``
    and then refits a single ``Lasso`` at the selected ``alpha_`` on the data
    passed to ``fit``.  ``predict``, ``score`` and ``summary`` delegate to that
    refitted estimator (``estimator_``).

    Attributes set by ``fit``
    -------------------------
    alpha_ : float
        Selected penalty strength.
    alphas_ : ndarray
        Grid reported by the CV routine (``details["alphas"]``).
    mse_path_ : ndarray
        Per-alpha / per-fold validation MSE (``details["mse_path"]``).
    mean_mse_ : ndarray
        Per-alpha mean validation MSE (``details["mean_mse"]``).
    best_score_ : float
        ``nanmin`` of ``mean_mse_`` or ``nan`` when no entry is finite.
    coef_, intercept_, n_iter_ :
        Copied from the refitted estimator.
    estimator_ : Lasso
        The refitted model.
    """

    def __init__(
        self,
        alphas=None,
        n_alphas: int = 12,
        alpha_min_ratio: float = 1e-3,
        cv: int = 5,
        cv_splits=None,
        fit_intercept: bool = True,
        max_iter: int = 3000,
        tol: float = 1e-4,
        stopping: str = "coef_delta",
        inference_method: str = "cpu_ols_inference",
        n_bootstrap: int = 200,
        bootstrap_random_state: Optional[int] = None,
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
        compute_inference: bool = True,
        solver: str = "fista",
        cpu_solver: str = "coordinate_descent",
        method: str = "standard",
        cd_kkt_check_every: Optional[int] = None,
        lipschitz_L: Optional[float] = None,
        admm_rho: float = 1.0,
        gpu_memory_cleanup: bool = False,
        gpu_cv_mixed_precision: bool = True,
        random_state: Optional[int] = None,
    ):
        """Store the hyper-parameters (coerced to plain Python types) and reset fitted state."""
        super().__init__(
            cv=cv,
            random_state=random_state,
            device=device,
            n_jobs=n_jobs,
        )

        # Alpha-grid configuration.
        self.alphas = alphas
        self.n_alphas = int(n_alphas)
        self.alpha_min_ratio = float(alpha_min_ratio)

        # Cross-validation configuration.
        self.cv = int(cv)
        self.cv_splits = cv_splits
        self.random_state = random_state

        # Solver configuration shared by the CV sweep and the final refit.
        self.fit_intercept = bool(fit_intercept)
        self.max_iter = int(max_iter)
        self.tol = float(tol)
        self.stopping = str(stopping)
        self.solver = str(solver)
        self.cpu_solver = str(cpu_solver)
        self.method = _normalize_lassocv_method(method)
        self.cd_kkt_check_every = _normalize_cd_kkt_check_every(cd_kkt_check_every)
        self.lipschitz_L = lipschitz_L
        self.admm_rho = float(admm_rho)
        self.gpu_memory_cleanup = bool(gpu_memory_cleanup)
        self.gpu_cv_mixed_precision = bool(gpu_cv_mixed_precision)

        # Inference configuration forwarded to the refitted ``Lasso``.
        self.inference_method = str(inference_method)
        self.n_bootstrap = int(n_bootstrap)
        self.bootstrap_random_state = bootstrap_random_state
        self.compute_inference = bool(compute_inference)

        # Fitted state, populated by ``fit``.
        self.alpha_ = None
        self.alphas_ = None
        self.mse_path_ = None
        self.mean_mse_ = None
        self.best_score_ = None
        self.coef_ = None
        self.intercept_ = None
        self.n_iter_ = None
        self.estimator_ = None

    def fit(self, X, y, sample_weight=None):
        """
        Select ``alpha`` by cross-validation and refit a ``Lasso`` at that value.

        Parameters
        ----------
        X, y :
            Training data, forwarded unchanged to the CV routine and the refit.
        sample_weight : optional
            Forwarded unchanged to the CV routine and the refit.

        Returns
        -------
        self
        """
        device_name = self._get_compute_device().value
        is_glmnet = str(self.method).lower() == "glmnet"

        # The "glmnet" method always uses coordinate descent on CPU.
        effective_cpu_solver = "coordinate_descent" if is_glmnet else str(self.cpu_solver)

        details = _select_lasso_alpha_cv(
            X,
            y,
            alphas=self.alphas,
            n_alphas=self.n_alphas,
            alpha_min_ratio=self.alpha_min_ratio,
            cv_folds=self.cv,
            cv_splits=self.cv_splits,
            random_state=self.random_state,
            sample_weight=sample_weight,
            fit_intercept=self.fit_intercept,
            device=device_name,
            max_iter=self.max_iter,
            tol=self.tol,
            cpu_solver=effective_cpu_solver,
            method=self.method,
            cd_kkt_check_every=self.cd_kkt_check_every,
            gpu_cv_mixed_precision=self.gpu_cv_mixed_precision,
            return_details=True,
        )

        # Default KKT-check cadence when the user did not pin one.
        effective_cd_kkt_check_every = self.cd_kkt_check_every
        if effective_cd_kkt_check_every is None:
            effective_cd_kkt_check_every = 4 if is_glmnet else 1

        self.alpha_ = float(details["alpha"])
        self.alphas_ = np.asarray(details["alphas"], dtype=np.float64)
        self.mse_path_ = np.asarray(details["mse_path"], dtype=np.float64)
        self.mean_mse_ = np.asarray(details["mean_mse"], dtype=np.float64)

        if np.any(np.isfinite(self.mean_mse_)):
            self.best_score_ = float(np.nanmin(self.mean_mse_))
        else:
            self.best_score_ = np.nan

        estimator = Lasso(
            alpha=self.alpha_,
            fit_intercept=self.fit_intercept,
            max_iter=self.max_iter,
            tol=self.tol,
            stopping=self.stopping,
            inference_method=self.inference_method,
            n_bootstrap=self.n_bootstrap,
            bootstrap_random_state=self.bootstrap_random_state,
            device=self.device,
            n_jobs=self.n_jobs,
            compute_inference=self.compute_inference,
            solver=self.solver,
            cpu_solver=effective_cpu_solver,
            lipschitz_L=self.lipschitz_L,
            admm_rho=self.admm_rho,
            gpu_memory_cleanup=self.gpu_memory_cleanup,
        )

        # The lightweight refit is only used when no inference is requested
        # and the solver / stopping rule are ones the fast helper is called with.
        fast_refit_enabled = (
            (not bool(self.compute_inference))
            and str(self.solver).lower() == "fista"
            and str(self.stopping).lower() in ("coef_delta", "kkt")
        )

        if fast_refit_enabled:
            fast = _fit_lasso_single_alpha_fast(
                X,
                y,
                alpha=float(self.alpha_),
                fit_intercept=bool(self.fit_intercept),
                max_iter=int(self.max_iter),
                tol=float(self.tol),
                stopping=str(self.stopping),
                device=str(device_name),
                cpu_solver=str(effective_cpu_solver),
                cd_kkt_check_every=int(effective_cd_kkt_check_every),
                sample_weight=sample_weight,
            )

            # Populate the estimator by hand instead of calling ``estimator.fit``.
            n_obs = int(fast["n_samples"])
            n_params = int(fast["n_features"]) + (1 if bool(self.fit_intercept) else 0)

            estimator.coef_ = np.asarray(fast["coef"], dtype=np.float64)
            estimator.intercept_ = float(fast["intercept"])
            estimator.n_iter_ = int(fast["n_iter"])
            estimator._nobs = n_obs
            estimator._df_resid = n_obs - n_params

            if bool(self.fit_intercept):
                estimator._params = np.concatenate(
                    [[estimator.intercept_], estimator.coef_]
                )
            else:
                estimator._params = estimator.coef_.copy()

            # No residual-based quantities are produced on the fast path.
            estimator._scale = np.nan
            estimator._resid = None
            estimator._X_design = None
            estimator._fitted = True
        else:
            estimator.fit(X, y, sample_weight=sample_weight)

        self.estimator_ = estimator
        self.coef_ = np.asarray(estimator.coef_)
        self.intercept_ = estimator.intercept_
        self.n_iter_ = int(estimator.n_iter_)

        self._fitted = True
        return self

    def predict(self, X):
        """Return ``estimator_.predict(X)``; requires a prior ``fit``."""
        self._check_is_fitted()
        return self.estimator_.predict(X)

    def score(self, X, y):
        """Return ``estimator_.score(X, y)``; requires a prior ``fit``."""
        self._check_is_fitted()
        return self.estimator_.score(X, y)

    def summary(self):
        """Return ``estimator_.summary()``; requires a prior ``fit``."""
        self._check_is_fitted()
        return self.estimator_.summary()
4596
+
4597
+
4598
+ # =============================================================================
4599
+ # Torch FISTA Solvers
4600
+ # =============================================================================
4601
+
4602
def _solve_lasso_path_gpu_fista_batched_from_gram_torch(
    XtX,
    Xty,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve descending-alpha Lasso path with a batched Torch FISTA update.

    All alphas are iterated together as columns of one coefficient matrix;
    columns that meet the stopping rule are dropped from the active set.

    Parameters
    ----------
    XtX : torch.Tensor
        Gram matrix; its dtype and device are used for every work tensor.
    Xty : torch.Tensor
        Right-hand side, reshaped internally to a column.
    n_samples : int
        Divisor of the smooth gradient ``(XtX @ b - Xty) / n_samples``
        (floored at 1).
    alphas_desc : array-like
        Penalty grid; converted with ``np.asarray(..., dtype=float64)``.
    max_iter : int
        Maximum number of FISTA iterations.
    tol : float
        Threshold for the stopping rule.
    stopping : str
        ``"kkt"`` compares ``max(|grad| - alpha, 0)`` over features with
        ``tol``; any other value uses the L1 norm of the last coefficient
        update.
    lipschitz_L : float, optional
        Step-size constant.  When omitted it is the largest eigenvalue of
        ``XtX`` divided by ``n_samples``, falling back to a row-sum bound if
        the eigenvalue routine raises.
    check_every : int
        Base cadence for convergence checks, adapted per iteration by
        ``_adaptive_gpu_check_every``.

    Returns
    -------
    (torch.Tensor, numpy.ndarray)
        Coefficients as the transpose of the ``(n_features, n_alphas)`` work
        matrix, and the int32 iteration count per alpha.
    """
    import torch

    # Convert first so that plain Python sequences are accepted as well.
    alphas_desc = np.asarray(alphas_desc, dtype=np.float64)

    dtype = XtX.dtype
    device = XtX.device
    n_features = int(XtX.shape[0])
    n_alphas = int(alphas_desc.shape[0])
    n_samp = float(max(1, n_samples))

    coefs = torch.zeros((n_features, n_alphas), dtype=dtype, device=device)
    yk = coefs.clone()
    tk = torch.ones((n_alphas,), dtype=dtype, device=device)
    n_iters_gpu = torch.zeros((n_alphas,), dtype=torch.int32, device=device)

    if lipschitz_L is not None:
        L = torch.tensor(float(lipschitz_L), dtype=dtype, device=device)
    else:
        try:
            # eigvalsh returns eigenvalues in ascending order.
            L = torch.linalg.eigvalsh(XtX)[-1] / n_samp
        except Exception:
            # Deliberate best-effort fallback: max absolute row sum bounds
            # the spectral radius.
            row_sum_bound = torch.max(torch.sum(torch.abs(XtX), dim=1)) / n_samp
            L = torch.clamp(row_sum_bound, min=1e-12)

    if float(L.item()) <= 0.0:
        # Degenerate Gram matrix: nothing to fit, every path stays at zero.
        return coefs.T, n_iters_gpu.cpu().numpy()

    step = 1.0 / L
    alpha_gpu = torch.from_numpy(alphas_desc).to(device).to(dtype)
    thresholds = alpha_gpu * step

    # Loop invariants.
    Xty_col = Xty.reshape(-1, 1)
    use_kkt = str(stopping).lower() == "kkt"
    check_every = max(1, int(check_every))
    max_iter = int(max_iter)
    tol = float(tol)

    active_gpu = torch.arange(n_alphas, dtype=torch.int64, device=device)

    for iteration in range(max_iter):
        if int(active_gpu.numel()) == 0:
            break

        # Index-tensor selection yields copies, so ``coef_old`` is a snapshot.
        y_active = yk[:, active_gpu]
        coef_old = coefs[:, active_gpu]

        # Gradient step followed by soft-thresholding.
        grad = (XtX @ y_active - Xty_col) / n_samp
        proximal_point = y_active - step * grad
        thresh = thresholds[active_gpu].reshape(1, -1)
        coef_new = torch.sign(proximal_point) * torch.clamp(
            torch.abs(proximal_point) - thresh, min=0.0
        )

        # Momentum update.
        t_old = tk[active_gpu]
        t_new = (1.0 + torch.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
        beta = (t_old - 1.0) / t_new
        y_new = coef_new + beta.reshape(1, -1) * (coef_new - coef_old)

        coefs[:, active_gpu] = coef_new
        yk[:, active_gpu] = y_new
        tk[active_gpu] = t_new

        active_ratio = float(int(active_gpu.numel())) / float(max(1, n_alphas))
        check_every_eff = _adaptive_gpu_check_every(
            base_check_every=check_every,
            iteration=iteration,
            max_iter=max_iter,
            active_ratio=active_ratio,
        )
        should_check = ((iteration + 1) % check_every_eff == 0) or (iteration + 1 == max_iter)
        if not should_check:
            continue

        if use_kkt:
            grad_sse = (XtX @ coef_new - Xty_col) / n_samp
            viol = torch.max(
                torch.clamp(
                    torch.abs(grad_sse) - alpha_gpu[active_gpu].reshape(1, -1),
                    min=0.0,
                ),
                dim=0,
            ).values
            converged_local_gpu = viol < tol
        else:
            delta = torch.sum(torch.abs(coef_new - coef_old), dim=0)
            converged_local_gpu = delta < tol

        done_gpu = active_gpu[converged_local_gpu]
        if int(done_gpu.numel()) == 0:
            continue

        n_iters_gpu[done_gpu] = int(iteration) + 1
        # Drop the momentum term for finished columns.
        yk[:, done_gpu] = coefs[:, done_gpu]
        active_gpu = active_gpu[~converged_local_gpu]

    # Columns still active exhausted the iteration budget.
    if int(active_gpu.numel()) > 0:
        n_iters_gpu[active_gpu] = max_iter

    return coefs.T, n_iters_gpu.cpu().numpy()
4706
+
4707
+
4708
def _solve_lasso_path_gpu_fista_multi_fold_from_gram_torch(
    XtX_batch,
    Xty_batch,
    *,
    n_samples_vec: np.ndarray,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve descending-alpha Lasso paths for all folds together on Torch GPU.

    Every (fold, alpha) pair is one FISTA problem; all of them are advanced
    with the same batched tensor operations and a boolean mask freezes the
    pairs that already met the stopping rule.

    Parameters
    ----------
    XtX_batch : torch.Tensor
        Per-fold Gram matrices, indexed ``(fold, feature, feature)``; dtype
        and device are used for every work tensor.
    Xty_batch : torch.Tensor
        Per-fold right-hand sides, reshaped to ``(n_folds, n_features, 1)``.
    n_samples_vec : numpy.ndarray or torch.Tensor
        One training-set size per fold (floored at 1 before use as the
        gradient divisor).
    alphas_desc : array-like
        Penalty grid shared by all folds.
    max_iter : int
        Maximum number of FISTA iterations.
    tol : float
        Threshold for the stopping rule.
    stopping : str
        ``"kkt"`` compares ``max(|grad| - alpha, 0)`` over features with
        ``tol``; any other value uses the L1 norm of the last coefficient
        update.
    lipschitz_L : float, optional
        Step-size constant applied to every fold.  When omitted it is the
        largest eigenvalue of each fold's Gram matrix divided by that fold's
        size, falling back to a row-sum bound if the eigenvalue routine raises.
    check_every : int
        Base cadence for convergence checks, adapted per iteration by
        ``_adaptive_gpu_check_every``.

    Returns
    -------
    (torch.Tensor, numpy.ndarray)
        Coefficients permuted to ``(n_folds, n_alphas, n_features)`` and the
        int32 iteration counts of shape ``(n_folds, n_alphas)``.

    Raises
    ------
    ValueError
        If ``n_samples_vec`` does not have exactly one entry per fold.
    """
    import torch

    # Convert first so that plain Python sequences are accepted as well.
    alphas_desc = np.asarray(alphas_desc, dtype=np.float64)

    dtype = XtX_batch.dtype
    device = XtX_batch.device
    n_folds = int(XtX_batch.shape[0])
    n_features = int(XtX_batch.shape[1])
    n_alphas = int(alphas_desc.shape[0])

    coefs = torch.zeros((n_folds, n_features, n_alphas), dtype=dtype, device=device)
    yk = coefs.clone()
    tk = torch.ones((n_folds, n_alphas), dtype=dtype, device=device)
    n_iters_gpu = torch.zeros((n_folds, n_alphas), dtype=torch.int32, device=device)

    # The signature advertises an ndarray, but a tensor works too; accept both.
    if isinstance(n_samples_vec, torch.Tensor):
        n_vec_cpu = n_samples_vec.detach().cpu().numpy()
    else:
        n_vec_cpu = np.asarray(n_samples_vec)
    n_vec_cpu = n_vec_cpu.astype(np.float64).reshape(-1)
    if n_vec_cpu.size != n_folds:
        raise ValueError("n_samples_vec must have one entry per fold")
    # Same floor as the single-fold solver's ``max(1, n_samples)``.
    n_vec_cpu = np.maximum(n_vec_cpu, 1.0)
    n_vec = torch.from_numpy(n_vec_cpu).to(device).to(dtype)

    if lipschitz_L is not None:
        L = torch.full((n_folds,), float(lipschitz_L), dtype=dtype, device=device)
    else:
        try:
            # eigvalsh returns eigenvalues in ascending order per batch entry.
            L = torch.linalg.eigvalsh(XtX_batch)[:, -1] / n_vec
        except Exception:
            # Deliberate best-effort fallback: max absolute row sum bounds
            # the spectral radius.
            L = torch.max(torch.sum(torch.abs(XtX_batch), dim=2), dim=1).values / n_vec
    # Floor on every path: a fold with a zero Gram matrix would otherwise give
    # an infinite step and NaN coefficients.
    L = torch.clamp(L, min=1e-12)

    step = 1.0 / L.reshape(n_folds, 1, 1)
    alpha_gpu = torch.from_numpy(alphas_desc).to(device).to(dtype).reshape(1, 1, n_alphas)
    thresholds = alpha_gpu * step

    # Loop invariants.
    Xty_expanded = Xty_batch.reshape(n_folds, n_features, 1)
    n_vec_expanded = n_vec.reshape(n_folds, 1, 1)
    use_kkt = str(stopping).lower() == "kkt"
    check_every = max(1, int(check_every))
    max_iter = int(max_iter)
    tol = float(tol)
    total_problems = float(max(1, n_folds * n_alphas))

    active_gpu = torch.ones((n_folds, n_alphas), dtype=torch.bool, device=device)
    active_count = int(n_folds * n_alphas)

    for iteration in range(max_iter):
        if active_count == 0:
            break

        active_expanded = active_gpu.unsqueeze(1)

        # ``coefs`` is only ever rebound (torch.where returns a new tensor),
        # never mutated in place, so a plain reference is a valid snapshot.
        coef_old = coefs

        # Gradient step followed by soft-thresholding.
        grad = (torch.matmul(XtX_batch, yk) - Xty_expanded) / n_vec_expanded
        proximal_point = yk - step * grad
        coef_candidate = torch.sign(proximal_point) * torch.clamp(
            torch.abs(proximal_point) - thresholds, min=0.0
        )
        coefs = torch.where(active_expanded, coef_candidate, coefs)

        # Momentum update, applied to active problems only.
        t_new = (1.0 + torch.sqrt(1.0 + 4.0 * (tk ** 2))) / 2.0
        beta = (tk - 1.0) / t_new
        y_candidate = coefs + beta.unsqueeze(1) * (coefs - coef_old)
        yk = torch.where(active_expanded, y_candidate, yk)
        tk = torch.where(active_gpu, t_new, tk)

        active_ratio = float(active_count) / total_problems
        check_every_eff = _adaptive_gpu_check_every(
            base_check_every=check_every,
            iteration=iteration,
            max_iter=max_iter,
            active_ratio=active_ratio,
        )
        should_check = ((iteration + 1) % check_every_eff == 0) or (iteration + 1 == max_iter)
        if not should_check:
            continue

        if use_kkt:
            grad_sse = (torch.matmul(XtX_batch, coefs) - Xty_expanded) / n_vec_expanded
            violation = torch.max(
                torch.clamp(torch.abs(grad_sse) - alpha_gpu, min=0.0),
                dim=1,
            ).values
            converged_local_gpu = violation < tol
        else:
            delta = torch.sum(torch.abs(coefs - coef_old), dim=1)
            converged_local_gpu = delta < tol

        newly_done_gpu = active_gpu & converged_local_gpu
        done_count = int(torch.count_nonzero(newly_done_gpu).item())
        if done_count == 0:
            continue

        n_iters_gpu[newly_done_gpu] = int(iteration) + 1
        # Drop the momentum term for finished problems.
        yk = torch.where(newly_done_gpu.unsqueeze(1), coefs, yk)
        active_gpu = active_gpu & (~converged_local_gpu)
        active_count -= done_count

    # Problems still active exhausted the iteration budget.
    n_iters_gpu[active_gpu] = max_iter

    return coefs.permute(0, 2, 1), n_iters_gpu.cpu().numpy()
4809
+
4810
# NOTE(review): this method-shaped definition follows the module-level Torch
# solvers rather than sitting inside a class body visible here. It reads
# ``self.estimator_`` and so looks like a stranded ``LassoCV.summary`` --
# confirm the intended placement.
def summary(self):
    """Return ``self.estimator_.summary()`` after the fitted-state check."""
    self._check_is_fitted()
    return self.estimator_.summary()
4813
+
4814
+
4815
+ # =============================================================================
4816
+ # V9 thin wrapper
4817
+ # =============================================================================
4818
+
4819
+ from ._penalized import PenalizedLinearRegression as _PenalizedLinearRegression
4820
+
4821
+
4822
+ class Lasso(_PenalizedLinearRegression):
4823
+ """Thin sklearn-style wrapper over ``PenalizedLinearRegression`` with L1 penalty."""
4824
+
4825
+ def __init__(
4826
+ self,
4827
+ alpha: float = 1.0,
4828
+ fit_intercept: bool = True,
4829
+ max_iter: int = 1000,
4830
+ tol: float = 1e-4,
4831
+ stopping: str = "coef_delta",
4832
+ inference_method: str = "cpu_ols_inference",
4833
+ n_bootstrap: int = 200,
4834
+ bootstrap_random_state: Optional[int] = None,
4835
+ enable_simultaneous_inference: bool = False,
4836
+ simultaneous_method: str = "maxz_bootstrap",
4837
+ simultaneous_alpha: float = 0.05,
4838
+ simultaneous_n_bootstrap: int = 1000,
4839
+ simultaneous_random_state: Optional[int] = None,
4840
+ simultaneous_include_intercept: bool = False,
4841
+ device: Union[str, Device] = Device.AUTO,
4842
+ n_jobs: Optional[int] = None,
4843
+ compute_inference: bool = True,
4844
+ solver: str = "fista",
4845
+ cpu_solver: str = "coordinate_descent",
4846
+ lipschitz_L: Optional[float] = None,
4847
+ admm_rho: float = 1.0,
4848
+ gpu_memory_cleanup: bool = False,
4849
+ **kwargs,
4850
+ ):
4851
+ self.stopping = str(stopping).lower()
4852
+ self.inference_method = str(inference_method).lower()
4853
+ self.n_bootstrap = int(n_bootstrap)
4854
+ self.bootstrap_random_state = bootstrap_random_state
4855
+ self.enable_simultaneous_inference = bool(enable_simultaneous_inference)
4856
+ self.simultaneous_method = str(simultaneous_method).lower()
4857
+ self.simultaneous_alpha = float(simultaneous_alpha)
4858
+ self.simultaneous_n_bootstrap = int(simultaneous_n_bootstrap)
4859
+ self.simultaneous_random_state = simultaneous_random_state
4860
+ self.simultaneous_include_intercept = bool(simultaneous_include_intercept)
4861
+ self.compute_inference = bool(compute_inference)
4862
+ self.admm_rho = float(admm_rho)
4863
+ self._ignored_kwargs = dict(kwargs)
4864
+ super().__init__(
4865
+ penalty="l1",
4866
+ alpha=alpha,
4867
+ fit_intercept=fit_intercept,
4868
+ max_iter=max_iter,
4869
+ tol=tol,
4870
+ device=device,
4871
+ n_jobs=n_jobs,
4872
+ cpu_solver=cpu_solver,
4873
+ solver=solver,
4874
+ lipschitz_L=lipschitz_L,
4875
+ gpu_memory_cleanup=gpu_memory_cleanup,
4876
+ )