statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2124 @@
1
+ """
2
+ Lasso regression with full statistical inference and GPU support.
3
+ """
4
+
5
+ __all__ = ["Lasso"]
6
+
7
+ from collections import OrderedDict
8
+ import hashlib
9
+ import threading
10
+ from typing import Any, Dict, Optional, Tuple, Union
11
+ import os
12
+ import warnings
13
+ import numpy as np
14
+
15
+ try:
16
+ from numba import njit
17
+
18
+ _NUMBA_AVAILABLE = True
19
+ except Exception:
20
+ njit = None
21
+ _NUMBA_AVAILABLE = False
22
+
23
+ from statgpu._base import BaseEstimator
24
+ from statgpu.backends import _to_numpy
25
+ from statgpu._config import Device
26
+ from statgpu.cross_validation._base import CVEstimatorBase, kfold_indices as _kfold_indices, batch_mse as _batch_mse_cv
27
+ from statgpu.backends import get_backend
28
+ from statgpu.inference._distributions_backend import (
29
+ norm,
30
+ t,
31
+ )
32
+
33
+
34
# Boolean flag parsed from the STATGPU_DISABLE_NUMBA_CD environment variable.
# Any of "1", "true", "yes", "on" (case-insensitive, surrounding whitespace
# ignored) turns it on; everything else, including the default "0", leaves it off.
# NOTE(review): presumably gates the Numba coordinate-descent path; the
# consumer is not visible in this part of the file - confirm.
_NUMBA_CD_DISABLED = str(os.getenv("STATGPU_DISABLE_NUMBA_CD", "0")).strip().lower() in (
    "1",
    "true",
    "yes",
    "on",
)

# Capacity of the cross-validation alpha cache (entries). A value <= 0 makes
# the cache get/put helpers below no-ops. int() raises ValueError at import
# time if the environment variable is not an integer literal.
_LASSO_CV_ALPHA_CACHE_MAXSIZE = int(os.getenv("STATGPU_LASSO_CV_CACHE_SIZE", "64"))
# Insertion-ordered mapping used as an LRU: hits are moved to the end and the
# oldest entry is evicted from the front.
_LASSO_CV_ALPHA_CACHE: "OrderedDict[Tuple[Any, ...], Dict[str, Any]]" = OrderedDict()
# Capacity and storage for the debiased-Lasso M-matrix cache (same LRU scheme).
_LASSO_DEBIASED_M_CACHE_MAXSIZE = int(os.getenv("STATGPU_LASSO_DEBIASED_M_CACHE_SIZE", "16"))
_LASSO_DEBIASED_M_CACHE: "OrderedDict[Tuple[Any, ...], np.ndarray]" = OrderedDict()
# NOTE(review): looks like a row-chunk size for hashing GPU-resident designs;
# not referenced in the visible part of the file - confirm against its user.
_LASSO_DEBIASED_M_GPU_HASH_ROW_CHUNK = 1024
# Single lock shared by both module-level caches above.
_cache_lock = threading.Lock()
47
+
48
+
49
+ # ============================================================================
50
+ # CuPy Fused Kernels for Lasso - Now implemented as Lasso class methods
51
+ # See Lasso._get_cupy_fused_kernels() for details.
52
+ # ============================================================================
53
+
54
+
55
def _debiased_m_cache_get(key):
    """Look up ``key`` in the debiased-M cache, refreshing its LRU position.

    Returns the cached value, or ``None`` when the key is absent (or maps to
    ``None``). The lookup and the recency update happen under ``_cache_lock``.
    """
    with _cache_lock:
        cached = _LASSO_DEBIASED_M_CACHE.get(key)
        if cached is None:
            return cached
        # A hit becomes the most recently used entry.
        _LASSO_DEBIASED_M_CACHE.move_to_end(key)
        return cached
61
+
62
+
63
def _debiased_m_cache_put(key, value):
    """Store ``value`` under ``key`` and evict oldest entries beyond capacity.

    The entry is marked most recently used; eviction removes entries from the
    front of the ordered mapping until the size is within
    ``_LASSO_DEBIASED_M_CACHE_MAXSIZE``. Runs under ``_cache_lock``.
    """
    with _cache_lock:
        store = _LASSO_DEBIASED_M_CACHE
        store[key] = value
        store.move_to_end(key)
        while len(store) > _LASSO_DEBIASED_M_CACHE_MAXSIZE:
            store.popitem(last=False)
69
+
70
+
71
def _debiased_m_key_from_numpy_design(
    X: np.ndarray,
    *,
    n: int,
    p: int,
    lam_nw: float,
    tol: float,
):
    """Build a cache key for the debiased M matrix from a full NumPy design.

    The key is a 32-byte BLAKE2b hex digest over, in order: ``(n, p)`` as
    int64, the dtype name of ``X``, ``(lam_nw, tol)`` as float64, and the raw
    bytes of ``X`` in C order.
    """
    design = np.asarray(X)
    if not design.flags["C_CONTIGUOUS"]:
        # Hash a C-ordered copy so memory layout does not change the key.
        design = np.ascontiguousarray(design)

    shape_bytes = np.asarray([int(n), int(p)], dtype=np.int64).tobytes()
    dtype_bytes = str(design.dtype).encode("utf-8")
    param_bytes = np.asarray([float(lam_nw), float(tol)], dtype=np.float64).tobytes()

    digest = hashlib.blake2b(digest_size=32)
    for chunk in (shape_bytes, dtype_bytes, param_bytes):
        digest.update(chunk)
    digest.update(design.view(np.uint8).tobytes())
    return digest.hexdigest()
88
+
89
+
90
def _debiased_m_key_from_sample(
    *,
    n: int,
    p: int,
    dtype_name: str,
    sample_block: np.ndarray,
    lam_nw: float,
    tol: float,
):
    """Build a cache key for the debiased M matrix from a sample block of X.

    Intended for the Torch backend, where hashing the entire matrix is
    avoided. The 32-byte BLAKE2b hex digest covers ``(n, p)`` as int64,
    ``dtype_name``, ``(lam_nw, tol)`` as float64, and the C-ordered bytes of
    ``sample_block``.
    """
    header = (
        np.asarray([int(n), int(p)], dtype=np.int64).tobytes(),
        dtype_name.encode("utf-8"),
        np.asarray([float(lam_nw), float(tol)], dtype=np.float64).tobytes(),
    )
    digest = hashlib.blake2b(digest_size=32)
    for chunk in header:
        digest.update(chunk)

    block_c = sample_block
    if not block_c.flags["C_CONTIGUOUS"]:
        # Normalise layout so equal data always hashes identically.
        block_c = np.ascontiguousarray(block_c)
    digest.update(block_c.view(np.uint8).tobytes())
    return digest.hexdigest()
111
+
112
+
113
+
114
def _lasso_alpha_heuristic(y_centered: np.ndarray, n_features: int) -> float:
    """Return the penalty level ``sigma_hat * sqrt(2 * log(p) / n)``.

    ``sigma_hat`` is the standard deviation of ``y_centered`` (sample
    estimate with ``ddof=1`` when more than one observation is available),
    floored at ``1e-8``. ``p`` is clamped to at least 2 and ``n`` to at
    least 1 so the expression stays finite for tiny inputs.
    """
    n_obs = int(y_centered.shape[0])
    ddof = 1 if n_obs > 1 else 0
    noise_scale = max(float(np.std(y_centered, ddof=ddof)), 1e-8)
    rate = np.sqrt(2.0 * np.log(max(2, int(n_features))) / max(1, n_obs))
    return float(noise_scale * rate)
123
+
124
+
125
def _default_lasso_alpha_grid(
    X: np.ndarray,
    y: np.ndarray,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
) -> np.ndarray:
    """Build a descending, geometrically spaced grid of Lasso penalties.

    The largest value is the maximum of ``max|X.T @ y| / n``, the heuristic
    from ``_lasso_alpha_heuristic`` and ``1e-6``. The smallest is
    ``alpha_min_ratio`` times the largest, floored at ``1e-6``. When
    ``n_alphas <= 1`` only the largest value is returned.
    """
    n_obs = int(X.shape[0])
    abs_corr = np.abs(X.T @ y) / float(max(1, n_obs))

    top = float(np.max(abs_corr)) if abs_corr.size else 1.0
    top = max(top, _lasso_alpha_heuristic(y, n_features=int(X.shape[1])))
    top = max(top, 1e-6)

    grid_size = int(n_alphas)
    if grid_size <= 1:
        return np.asarray([top], dtype=np.float64)

    bottom = max(float(alpha_min_ratio) * top, 1e-6)
    return np.geomspace(top, bottom, num=grid_size).astype(np.float64)
142
+
143
+
144
def _default_lasso_alpha_grid_backend(
    X,
    y,
    backend,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
) -> np.ndarray:
    """Generate the default Lasso alpha grid through the backend abstraction.

    Computes the same quantities as the NumPy grid builder but routes array
    work through ``backend`` (``asarray``, ``abs``, ``max``, ``sum``,
    ``mean``, ``sqrt``, ``to_numpy``). The returned grid is always a NumPy
    float64 array in descending order.
    """
    design = backend.asarray(X, dtype=backend.float64)
    target = backend.asarray(y, dtype=backend.float64).reshape(-1)
    n_obs = int(design.shape[0])

    abs_corr = backend.abs(design.T @ target) / float(max(1, n_obs))
    # Size via shape so the check works for both numpy- and torch-like arrays.
    n_corr = int(abs_corr.shape[0]) if hasattr(abs_corr, 'shape') else len(abs_corr)
    top = float(backend.to_numpy(backend.max(abs_corr))) if n_corr > 0 else 1.0

    noise_scale = 0.0
    if n_obs > 1:
        # Sample variance (ddof=1), matching _lasso_alpha_heuristic.
        sample_var = backend.sum((target - backend.mean(target)) ** 2) / (n_obs - 1)
        noise_scale = float(backend.to_numpy(backend.sqrt(sample_var)))
    noise_scale = max(noise_scale, 1e-8)

    rate = np.sqrt(2.0 * np.log(max(2, int(design.shape[1]))) / max(1, n_obs))
    top = max(top, float(noise_scale * rate), 1e-6)

    grid_size = int(n_alphas)
    if grid_size <= 1:
        return np.asarray([top], dtype=np.float64)

    bottom = max(float(alpha_min_ratio) * top, 1e-6)
    return np.geomspace(top, bottom, num=grid_size).astype(np.float64)
177
+
178
+
179
def _default_lasso_alpha_grid_cupy(
    X,
    y,
    n_alphas: int = 12,
    alpha_min_ratio: float = 1e-3,
) -> np.ndarray:
    """Generate the default Lasso alpha grid with CuPy doing the reductions.

    Inputs are converted to float64 CuPy arrays; only scalars are pulled back
    to the host. The returned grid is a descending NumPy float64 array.
    """
    import cupy as cp

    design = cp.asarray(X, dtype=cp.float64)
    target = cp.asarray(y, dtype=cp.float64).reshape(-1)
    n_obs = int(design.shape[0])

    abs_corr = cp.abs(design.T @ target) / float(max(1, n_obs))
    top = float(cp.max(abs_corr).item()) if int(abs_corr.size) > 0 else 1.0

    # Sample std (ddof=1) when possible, population std for a single row.
    ddof = 1 if n_obs > 1 else 0
    noise_scale = max(float(cp.std(target, ddof=ddof).item()), 1e-8)

    rate = np.sqrt(2.0 * np.log(max(2, int(design.shape[1]))) / max(1, n_obs))
    top = max(top, float(noise_scale * rate), 1e-6)

    grid_size = int(n_alphas)
    if grid_size <= 1:
        return np.asarray([top], dtype=np.float64)

    bottom = max(float(alpha_min_ratio) * top, 1e-6)
    return np.geomspace(top, bottom, num=grid_size).astype(np.float64)
208
+
209
+
210
def _normalize_cv_splits(cv_splits, n_samples: int):
    """Validate user-supplied CV splits and convert them to int64 index arrays.

    Returns ``None`` when ``cv_splits`` is ``None``; otherwise a list of
    ``(train_idx, val_idx)`` pairs of flat int64 arrays. Splits with an empty
    train or validation side are dropped.

    Raises:
        ValueError: if an entry is not a 2-item tuple/list, if any index is
            outside ``[0, n_samples)``, or if no non-empty split remains.
    """
    if cv_splits is None:
        return None

    limit = int(n_samples)

    def _has_bad_index(idx: np.ndarray) -> bool:
        return bool(np.any(idx < 0)) or bool(np.any(idx >= limit))

    normalized = []
    for entry in cv_splits:
        is_pair = isinstance(entry, (tuple, list)) and len(entry) == 2
        if not is_pair:
            raise ValueError("Each cv_splits entry must be a (train_idx, val_idx) pair")

        train = np.asarray(entry[0], dtype=np.int64).reshape(-1)
        val = np.asarray(entry[1], dtype=np.int64).reshape(-1)

        # Degenerate folds are skipped rather than rejected.
        if train.size == 0 or val.size == 0:
            continue

        if _has_bad_index(train) or _has_bad_index(val):
            raise ValueError("cv_splits indices are out of range")

        normalized.append((train, val))

    if not normalized:
        raise ValueError("cv_splits must contain at least one non-empty split")

    return normalized
241
+
242
+
243
def _folds_are_complements(folds, n_samples: int) -> bool:
    """Return True when every fold's train set is the exact complement of its validation set.

    A fold qualifies when train and validation indices together number
    ``n_samples``, do not overlap, and jointly cover every sample index.
    """
    total = int(n_samples)
    for train_idx, val_idx in folds:
        train = np.asarray(train_idx, dtype=np.int64).reshape(-1)
        val = np.asarray(val_idx, dtype=np.int64).reshape(-1)

        if int(train.size + val.size) != total:
            return False

        seen = np.zeros((total,), dtype=np.int8)
        seen[train] = 1
        # Any validation index already marked by train means overlap.
        overlaps = bool(np.any(seen[val] != 0))
        if overlaps:
            return False
        seen[val] = 1
        # Unmarked positions mean some sample is in neither set.
        has_gap = bool(np.any(seen == 0))
        if has_gap:
            return False

    return True
262
+
263
+
264
def _array_identity_token(x: Any) -> Tuple[Any, ...]:
    """Content-based hash token for array cache keys.

    Hashes up to 100 evenly spaced rows (BLAKE2b, 16-byte digest, values cast
    to float64) to keep hashing fast for large arrays while avoiding false
    cache hits from memory pointer reuse. For CuPy arrays and Torch tensors
    the rows are sampled on the device and only the sample is transferred.

    Args:
        x: ``None``, a CuPy array, a Torch tensor, or anything
            ``np.asarray(..., dtype=float64)`` accepts (including scalars).

    Returns:
        ``("none",)`` for ``None``; otherwise
        ``(kind, hex_digest, shape, dtype_name)`` with ``kind`` one of
        ``"cupy"``, ``"torch"``, ``"numpy"``.

    Notes:
        Only the optional imports are best-effort. A failure while hashing a
        CuPy/Torch input now propagates instead of being swallowed and
        retried as a NumPy conversion. 0-d inputs are hashed as a single
        value rather than raising ``IndexError``.
    """
    if x is None:
        return ("none",)

    max_rows = 100

    def _digest(sample_np) -> str:
        data = np.ascontiguousarray(np.asarray(sample_np, dtype=np.float64))
        return hashlib.blake2b(data.tobytes(), digest_size=16).hexdigest()

    def _shape_of(arr) -> Tuple[int, ...]:
        return tuple(int(v) for v in arr.shape)

    def _leading_rows(arr) -> int:
        # 0-d inputs have no leading axis; treat them as "nothing to sample".
        return int(arr.shape[0]) if len(arr.shape) > 0 else 0

    try:
        import cupy as cp
    except Exception:  # optional dependency; import can also fail without a usable CUDA runtime
        cp = None
    if cp is not None and isinstance(x, cp.ndarray):
        n_rows = _leading_rows(x)
        sample = x
        if n_rows > max_rows:
            # Sample on the GPU first, then transfer only the sampled rows.
            sample = x[cp.linspace(0, n_rows - 1, max_rows, dtype=cp.int64)]
        return ("cupy", _digest(cp.asnumpy(sample)), _shape_of(x), str(x.dtype))

    try:
        import torch
    except Exception:  # optional dependency
        torch = None
    if torch is not None and isinstance(x, torch.Tensor):
        n_rows = _leading_rows(x)
        sample = x
        if n_rows > max_rows:
            row_idx = torch.linspace(0, n_rows - 1, max_rows, dtype=torch.long, device=x.device)
            sample = x[row_idx]
        return ("torch", _digest(sample.detach().cpu().numpy()), _shape_of(x), str(x.dtype))

    arr = np.asarray(x, dtype=np.float64)
    n_rows = _leading_rows(arr)
    sample = arr
    if n_rows > max_rows:
        sample = arr[np.linspace(0, n_rows - 1, max_rows, dtype=int)]
    return ("numpy", _digest(sample), _shape_of(arr), str(arr.dtype))
324
+
325
+
326
def _alphas_signature(alphas: np.ndarray) -> str:
    """Return a 16-byte BLAKE2b hex digest of ``alphas`` flattened to float64."""
    flat = np.asarray(alphas, dtype=np.float64).reshape(-1)
    payload = np.ascontiguousarray(flat).tobytes()
    return hashlib.blake2b(payload, digest_size=16).hexdigest()
329
+
330
+
331
def _folds_signature(folds) -> str:
    """Return a 16-byte BLAKE2b hex digest identifying a sequence of CV folds.

    Each fold contributes its train indices, ``b"|"``, its validation
    indices, then ``b";"``; indices are flattened and encoded as int64.
    """

    def _index_bytes(indices) -> bytes:
        flat = np.asarray(indices, dtype=np.int64).reshape(-1)
        return np.ascontiguousarray(flat).tobytes()

    digest = hashlib.blake2b(digest_size=16)
    for train_idx, val_idx in folds:
        train_bytes = _index_bytes(train_idx)
        val_bytes = _index_bytes(val_idx)
        digest.update(train_bytes + b"|" + val_bytes + b";")
    return digest.hexdigest()
341
+
342
+
343
def _make_lasso_cv_auto_cache_key(
    *,
    X,
    y,
    sample_weight,
    alpha_grid: np.ndarray,
    folds,
    fit_intercept: bool,
    use_gpu: bool,
    max_iter: int,
    tol: float,
    cpu_solver: str,
    cv_method: str,
    cd_kkt_check_every: Optional[int],
    gpu_cv_mixed_precision: bool,
) -> Tuple[Any, ...]:
    """Assemble the hashable key used by the Lasso CV alpha cache.

    The key starts with the version tag ``"lasso_cv_auto_v1"``, followed by
    content tokens for the data (X, y, sample_weight), digests of the alpha
    grid and folds, and the solver settings normalised to plain Python types.
    """
    data_part = (
        _array_identity_token(X),
        _array_identity_token(y),
        _array_identity_token(sample_weight),
        _alphas_signature(alpha_grid),
        _folds_signature(folds),
    )
    kkt_cadence = None if cd_kkt_check_every is None else int(cd_kkt_check_every)
    settings_head = (
        bool(fit_intercept),
        bool(use_gpu),
        int(max_iter),
        float(tol),
        str(cpu_solver).lower(),
        str(cv_method).lower(),
    )
    settings_tail = (kkt_cadence, bool(gpu_cv_mixed_precision))
    return ("lasso_cv_auto_v1",) + data_part + settings_head + settings_tail
375
+
376
+
377
def _clone_lasso_cv_cache_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return an independent copy of a CV cache payload.

    ``alpha`` is coerced to a Python float; ``alphas``, ``mse_path`` and
    ``mean_mse`` are copied as float64 arrays so callers cannot mutate the
    cached data. Any other keys in ``payload`` are dropped.
    """
    clone: Dict[str, Any] = {"alpha": float(payload["alpha"])}
    for field in ("alphas", "mse_path", "mean_mse"):
        clone[field] = np.asarray(payload[field], dtype=np.float64).copy()
    return clone
384
+
385
+
386
def _lasso_cv_cache_get(cache_key: Optional[Tuple[Any, ...]]) -> Optional[Dict[str, Any]]:
    """Fetch a copy of the cached CV result for ``cache_key``.

    Returns ``None`` when the key is ``None``, the cache is disabled
    (capacity <= 0), or the key is not present. A hit is promoted to most
    recently used and returned as a fresh clone.
    """
    caching_enabled = _LASSO_CV_ALPHA_CACHE_MAXSIZE > 0
    if cache_key is None or not caching_enabled:
        return None

    with _cache_lock:
        entry = _LASSO_CV_ALPHA_CACHE.get(cache_key)
        if entry is not None:
            _LASSO_CV_ALPHA_CACHE.move_to_end(cache_key)
            return _clone_lasso_cv_cache_payload(entry)
        return None
396
+
397
+
398
def _lasso_cv_cache_put(cache_key: Optional[Tuple[Any, ...]], payload: Dict[str, Any]) -> None:
    """Store a clone of ``payload`` in the CV alpha cache under ``cache_key``.

    Does nothing when the key is ``None`` or the cache is disabled
    (capacity <= 0). After insertion the oldest entries are evicted until the
    cache is within capacity.
    """
    caching_enabled = _LASSO_CV_ALPHA_CACHE_MAXSIZE > 0
    if cache_key is None or not caching_enabled:
        return

    with _cache_lock:
        store = _LASSO_CV_ALPHA_CACHE
        store[cache_key] = _clone_lasso_cv_cache_payload(payload)
        store.move_to_end(cache_key)
        while len(store) > int(_LASSO_CV_ALPHA_CACHE_MAXSIZE):
            store.popitem(last=False)
407
+
408
+
409
def _adaptive_gpu_check_every(
    *,
    base_check_every: int,
    iteration: int,
    max_iter: int,
    active_ratio: float,
) -> int:
    """Adaptive cadence for expensive global convergence checks on GPU.

    The interval grows with the fraction of active coordinates
    (``active_ratio``, clamped to [0, 1]) and is capped late in the run:
    at most 4 once 75% of ``max_iter`` has elapsed and at most 2 from 90%.
    The result is always at least 1.
    """
    base = max(1, int(base_check_every))
    ratio = float(max(0.0, min(1.0, active_ratio)))

    if ratio < 0.15:
        interval = max(2, base // 2)
    elif ratio < 0.40:
        interval = max(4, base)
    elif ratio < 0.75:
        interval = max(base, 12)
    else:
        interval = max(base, 16)

    # Check more often as the iteration budget runs out.
    progress = float(iteration + 1) / float(max(1, int(max_iter)))
    late_cap = None
    if progress >= 0.90:
        late_cap = 2
    elif progress >= 0.75:
        late_cap = 4
    if late_cap is not None:
        interval = min(interval, late_cap)

    return max(1, int(interval))
436
+
437
+
438
def _soft_threshold_numpy(x: np.ndarray, gamma: float) -> np.ndarray:
    """Element-wise soft-thresholding: ``sign(x) * max(|x| - gamma, 0)``.

    ``gamma`` is converted to a float64 array, so it may be a scalar or any
    array that broadcasts against ``x``.
    """
    shrink = np.asarray(gamma, dtype=np.float64)
    excess = np.abs(x) - shrink
    return np.sign(x) * np.maximum(excess, 0.0)
441
+
442
+
443
def _soft_threshold_scalar(x: float, gamma: float) -> float:
    """Soft-threshold a single value: 0.0 if ``|x| <= gamma``, else ``sign(x) * (|x| - gamma)``."""
    magnitude = abs(float(x))
    shrink = float(gamma)
    if not magnitude <= shrink:
        return float(np.sign(x) * (magnitude - shrink))
    return 0.0
449
+
450
+
451
# The compiled kernels are only defined when Numba imported successfully;
# otherwise `njit` is None and the decorators below could not be applied.
if _NUMBA_AVAILABLE:

    @njit(cache=True)
    def _soft_threshold_scalar_numba(x: float, gamma: float) -> float:
        """Scalar soft-thresholding: 0.0 if |x| <= gamma, else x shrunk toward zero by gamma."""
        ax = abs(x)
        if ax <= gamma:
            return 0.0
        if x >= 0.0:
            return ax - gamma
        return -(ax - gamma)


    @njit(cache=True)
    def _solve_lasso_path_cpu_cd_numba_impl(
        XtX: np.ndarray,
        Xty: np.ndarray,
        n_samples: int,
        alphas_desc: np.ndarray,
        max_iter: int,
        tol: float,
        stopping_is_kkt: bool,
        cd_kkt_check_every: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Coordinate-descent Lasso path on a precomputed Gram system.

        Works only with ``XtX`` and ``Xty``; the design matrix itself is never
        touched. Alphas are processed in the order given and the coefficient
        vector is carried over from one alpha to the next.

        Parameters
        ----------
        XtX, Xty : Gram matrix (p x p) and cross-product vector (p,).
        n_samples : used to scale alphas (``alpha * n``) and gradients
            (``grad / n``); clamped to at least 1.
        alphas_desc : penalties to solve for, in path order.
        max_iter : iteration cap per alpha.
        tol : threshold for both the coefficient-change and the
            KKT-violation tests.
        stopping_is_kkt : True -> stop on max KKT violation < tol at a scan;
            False -> stop when the L1 coefficient change < tol and the scan
            found no violating inactive coordinate.
        cd_kkt_check_every : cadence (in sweeps) of the full KKT scan.

        Returns
        -------
        (coefs_path, n_iters) : ``coefs_path[k]`` holds the coefficients for
            ``alphas_desc[k]``; ``n_iters[k]`` is the number of sweeps used
            (``max_iter`` when the stop test never fired).
        """
        n_features = XtX.shape[0]
        n_alphas = alphas_desc.shape[0]

        coefs_path = np.zeros((n_alphas, n_features), dtype=np.float64)
        n_iters = np.zeros((n_alphas,), dtype=np.int32)

        # State shared across the whole path (warm start). With coef == 0,
        # grad == XtX @ coef - Xty reduces to -Xty; it is kept in sync
        # incrementally after every coordinate update below.
        coef = np.zeros((n_features,), dtype=np.float64)
        grad = -Xty.copy()

        # Diagonal of the Gram matrix = squared column norms of X.
        X_sq_norms = np.empty((n_features,), dtype=np.float64)
        for j in range(n_features):
            X_sq_norms[j] = XtX[j, j]

        # The coordinate update works on the un-normalised objective, so the
        # penalty is scaled by n once up front.
        n_samp = float(max(1, n_samples))
        alpha_scaled_desc = np.empty((n_alphas,), dtype=np.float64)
        for idx in range(n_alphas):
            alpha_scaled_desc[idx] = alphas_desc[idx] * n_samp

        # The active set only ever grows: entries are set to True and never
        # cleared, so it accumulates across alphas.
        active_mask = np.zeros((n_features,), dtype=np.bool_)
        check_every = max(1, int(cd_kkt_check_every))

        for alpha_idx in range(n_alphas):
            alpha = float(alphas_desc[alpha_idx])
            alpha_scaled = float(alpha_scaled_desc[alpha_idx])
            if alpha_idx > 0:
                prev_alpha_scaled = float(alpha_scaled_desc[alpha_idx - 1])
            else:
                prev_alpha_scaled = alpha_scaled

            # Screening threshold 2*alpha_k - alpha_{k-1} (scaled by n),
            # floored at zero. Note it is compared against |Xty|, not against
            # the current gradient; the KKT scan below re-admits anything
            # this screening misses.
            strong_thresh = 2.0 * alpha_scaled - prev_alpha_scaled
            if strong_thresh < 0.0:
                strong_thresh = 0.0

            any_active = False
            max_abs_xty = -1.0
            max_abs_xty_idx = 0
            for j in range(n_features):
                abs_xty = abs(Xty[j])
                if abs_xty >= strong_thresh:
                    active_mask[j] = True
                    any_active = True
                if abs_xty > max_abs_xty:
                    max_abs_xty = abs_xty
                    max_abs_xty_idx = j

            # If screening admitted nothing for this alpha, activate the
            # coordinate with the largest |Xty|.
            if not any_active:
                active_mask[max_abs_xty_idx] = True

            converged = False

            for iteration in range(int(max_iter)):
                coef_delta_l1 = 0.0

                # One sweep over the active coordinates.
                for j in range(n_features):
                    if not active_mask[j]:
                        continue

                    denom = float(X_sq_norms[j])
                    old_val = float(coef[j])

                    if denom > 1e-10:
                        # rho_j = Xty_j - sum_{k != j} XtX_jk * coef_k
                        rho_j = -float(grad[j]) + denom * old_val
                        new_val = _soft_threshold_scalar_numba(rho_j, alpha_scaled) / denom
                    else:
                        # Column with (near-)zero norm: coefficient pinned to 0.
                        new_val = 0.0

                    delta = new_val - old_val
                    if delta != 0.0:
                        coef[j] = new_val
                        coef_delta_l1 += abs(delta)
                        # Rank-one update keeps grad == XtX @ coef - Xty.
                        for row_idx in range(n_features):
                            grad[row_idx] += XtX[row_idx, j] * delta

                # Scan all coordinates on the cadence, whenever the sweep
                # looks converged, and on the final allowed sweep.
                should_kkt_scan = (
                    ((iteration + 1) % check_every == 0)
                    or (coef_delta_l1 < float(tol))
                    or (iteration + 1 == int(max_iter))
                )

                violation = 0.0
                has_inactive_violation = False

                if should_kkt_scan:
                    for j in range(n_features):
                        # Amount by which |grad_j| / n exceeds alpha, floored at 0.
                        v = abs(grad[j] / n_samp) - alpha
                        if v < 0.0:
                            v = 0.0
                        if v > violation:
                            violation = v
                        # A violating coordinate outside the active set is
                        # admitted so the next sweep can update it.
                        if v > float(tol) and (not active_mask[j]):
                            active_mask[j] = True
                            has_inactive_violation = True

                if stopping_is_kkt:
                    if should_kkt_scan and violation < float(tol):
                        n_iters[alpha_idx] = int(iteration) + 1
                        converged = True
                        break
                else:
                    if coef_delta_l1 < float(tol) and (not has_inactive_violation):
                        n_iters[alpha_idx] = int(iteration) + 1
                        converged = True
                        break

            if not converged:
                n_iters[alpha_idx] = int(max_iter)

            # Record the solution; non-zero coordinates stay active for the
            # next alpha.
            for j in range(n_features):
                coefs_path[alpha_idx, j] = coef[j]
                if abs(coef[j]) > 0.0:
                    active_mask[j] = True

        return coefs_path, n_iters
587
+
588
+
589
def _solve_lasso_path_cpu_cd_numba(
    XtX: np.ndarray,
    Xty: np.ndarray,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    cd_kkt_check_every: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Python-side entry point for the Numba coordinate-descent path solver.

    Normalises every argument to the concrete types the compiled kernel
    expects (C-contiguous float64 arrays, plain ints/floats/bools) and maps
    ``stopping`` to a boolean: ``True`` only for ``"kkt"`` (case-insensitive).
    Returns the kernel's ``(coefs_path, n_iters)`` pair unchanged.
    """
    gram = np.ascontiguousarray(XtX, dtype=np.float64)
    cross = np.ascontiguousarray(Xty, dtype=np.float64)
    alphas = np.ascontiguousarray(np.asarray(alphas_desc, dtype=np.float64))
    use_kkt_stop = str(stopping).lower() == "kkt"

    kernel_args = (
        gram,
        cross,
        int(n_samples),
        alphas,
        int(max_iter),
        float(tol),
        bool(use_kkt_stop),
        int(cd_kkt_check_every),
    )
    return _solve_lasso_path_cpu_cd_numba_impl(*kernel_args)
614
+
615
+
616
def _normalize_lassocv_method(method: str) -> str:
    """Normalize CV optimization profile name.

    Matching is case-insensitive and ignores surrounding whitespace.
    ``"default"``/``"classic"`` map to ``"standard"``;
    ``"glmnet_cv"``/``"glmnet.cv"`` map to ``"glmnet"``.

    Raises:
        ValueError: if the name is not a recognised profile or alias.
    """
    name = str(method).strip().lower()
    if name in ("default", "classic"):
        name = "standard"
    elif name in ("glmnet_cv", "glmnet.cv"):
        name = "glmnet"

    if name == "standard" or name == "glmnet":
        return name
    raise ValueError("method must be one of: 'standard', 'glmnet'")
629
+
630
+
631
def _normalize_cd_kkt_check_every(cd_kkt_check_every: Optional[int]) -> Optional[int]:
    """Validate optional coordinate-descent global KKT scan cadence.

    ``None`` passes through unchanged; any other value is converted with
    ``int()`` and must be strictly positive.

    Raises:
        ValueError: if the converted value is zero or negative.
    """
    if cd_kkt_check_every is not None:
        cadence = int(cd_kkt_check_every)
        if cadence > 0:
            return cadence
        raise ValueError("cd_kkt_check_every must be a positive integer or None")
    return None
639
+
640
+
641
def _solve_lasso_path_cpu_fista_batched_from_gram(
    XtX: np.ndarray,
    Xty: np.ndarray,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 2,
) -> tuple[np.ndarray, np.ndarray]:
    """Solve descending-alpha Lasso path with a batched CPU FISTA update.

    All alphas are iterated together, one column per alpha, starting from
    zero. Columns whose stop test passes are frozen and removed from the
    active batch. The stop test runs every ``check_every`` iterations and on
    the last one: max KKT violation < ``tol`` when ``stopping == "kkt"``
    (case-insensitive), otherwise L1 coefficient change < ``tol``.

    The step size is ``1 / L``, where ``L`` is ``lipschitz_L`` if given, else
    the largest eigenvalue of ``XtX / n`` (falling back to a row-sum bound if
    the eigenvalue computation fails). If ``L <= 0`` the all-zero path is
    returned immediately.

    Returns ``(coefs, n_iters)`` with ``coefs`` of shape (n_alphas,
    n_features) and ``n_iters`` giving iterations used per alpha.
    """
    n_features = int(XtX.shape[0])
    n_alphas = int(alphas_desc.shape[0])

    coefs = np.zeros((n_features, n_alphas), dtype=np.float64)
    momentum = coefs.copy()
    t_seq = np.ones((n_alphas,), dtype=np.float64)
    n_iters = np.zeros((n_alphas,), dtype=np.int32)

    if lipschitz_L is None:
        try:
            spectrum = np.linalg.eigvalsh(XtX)
            L = float(spectrum[-1] / float(max(1, n_samples)))
        except Exception:
            # Max absolute row sum is a cheap upper bound on the top eigenvalue.
            row_sum_bound = float(np.max(np.sum(np.abs(XtX), axis=1)) / float(max(1, n_samples)))
            L = max(row_sum_bound, 1e-12)
    else:
        L = float(lipschitz_L)

    if L <= 0.0:
        return coefs.T, n_iters

    n_samp = float(max(1, n_samples))
    step = 1.0 / L
    alphas_desc = np.asarray(alphas_desc, dtype=np.float64)
    thresholds = alphas_desc * step
    use_kkt_stop = str(stopping).lower() == "kkt"
    check_every = max(1, int(check_every))
    iter_cap = int(max_iter)
    tol_value = float(tol)
    Xty_col = Xty.reshape(-1, 1)

    active = np.arange(n_alphas, dtype=np.int64)

    for iteration in range(iter_cap):
        if active.size == 0:
            break

        # Fancy indexing copies, so these are snapshots of the active columns.
        y_cols = momentum[:, active]
        prev_cols = coefs[:, active]

        gradient = (XtX @ y_cols - Xty_col) / n_samp
        new_cols = _soft_threshold_numpy(
            y_cols - step * gradient,
            thresholds[active].reshape(1, -1),
        )

        t_prev = t_seq[active]
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * (t_prev ** 2))) / 2.0
        weight = (t_prev - 1.0) / t_next

        coefs[:, active] = new_cols
        momentum[:, active] = new_cols + weight.reshape(1, -1) * (new_cols - prev_cols)
        t_seq[active] = t_next

        completed = iteration + 1
        if completed % check_every != 0 and completed != iter_cap:
            continue

        if use_kkt_stop:
            grad_at_coef = (XtX @ new_cols - Xty_col) / n_samp
            excess = np.maximum(
                np.abs(grad_at_coef) - alphas_desc[active].reshape(1, -1),
                0.0,
            )
            finished = np.max(excess, axis=0) < tol_value
        else:
            finished = np.sum(np.abs(new_cols - prev_cols), axis=0) < tol_value

        if not np.any(finished):
            continue

        done = active[finished]
        n_iters[done] = int(iteration) + 1
        # Drop the momentum of finished columns so they rest at the solution.
        momentum[:, done] = coefs[:, done]
        active = active[~finished]

    if active.size > 0:
        n_iters[active] = iter_cap

    return coefs.T, n_iters
734
+
735
+
736
def _solve_lasso_path_gpu_fista_batched_from_gram(
    XtX,
    Xty,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Run FISTA for every alpha of a Lasso path simultaneously with CuPy.

    Each alpha owns one column of the coefficient matrix. Columns that pass
    the convergence test are removed from the active index set so later
    iterations only touch the unfinished alphas.

    Parameters
    ----------
    XtX, Xty
        Precomputed Gram matrix and cross-product vector (CuPy arrays).
    n_samples
        Row count used to scale the gradient; values below 1 are treated as 1.
    alphas_desc
        Penalty values, one per path point.
    max_iter, tol
        Iteration cap and convergence tolerance.
    stopping
        ``"kkt"`` (case-insensitive) selects the KKT-violation test; any other
        value selects the L1 coefficient-change test.
    lipschitz_L
        Optional precomputed Lipschitz constant. When omitted it is taken from
        the largest eigenvalue of ``XtX``, falling back to a row-sum bound if
        the eigen-solver raises.
    check_every
        Base interval between convergence checks, adapted per iteration by
        ``_adaptive_gpu_check_every``.

    Returns
    -------
    tuple
        ``(coefs, n_iters)``: a CuPy array of shape ``(n_alphas, n_features)``
        and a NumPy ``int32`` array of per-alpha iteration counts.
    """
    import cupy as cp

    work_dtype = XtX.dtype
    n_features = int(XtX.shape[0])
    n_alphas = int(alphas_desc.shape[0])
    n_eff = float(max(1, n_samples))

    coefs = cp.zeros((n_features, n_alphas), dtype=work_dtype)
    momentum_pts = coefs.copy()
    t_vals = cp.ones((n_alphas,), dtype=work_dtype)
    iters_dev = cp.zeros((n_alphas,), dtype=cp.int32)

    if lipschitz_L is None:
        try:
            lipschitz = cp.linalg.eigvalsh(XtX)[-1] / n_eff
        except Exception:
            # Eigen-solver failed: use the max absolute row sum as a bound.
            bound = cp.max(cp.sum(cp.abs(XtX), axis=1)) / n_eff
            lipschitz = cp.maximum(bound, cp.asarray(1e-12, dtype=work_dtype))
    else:
        lipschitz = cp.array(float(lipschitz_L), dtype=work_dtype)

    # A non-positive constant leaves no usable step size: return the zero path.
    if float(cp.asnumpy(lipschitz)) <= 0.0:
        return coefs.T, np.zeros((n_alphas,), dtype=np.int32)

    step = 1.0 / lipschitz
    alpha_dev = cp.asarray(np.asarray(alphas_desc, dtype=np.float64), dtype=work_dtype)
    shrink = alpha_dev * step
    use_kkt = str(stopping).lower() == "kkt"
    base_every = max(1, int(check_every))
    total_iters = int(max_iter)
    tol_value = float(tol)
    rhs = Xty.reshape(-1, 1)

    live = cp.arange(n_alphas, dtype=cp.int32)

    for it in range(total_iters):
        if int(live.size) == 0:
            break

        y_live = momentum_pts[:, live]
        prev = coefs[:, live]

        # Gradient step followed by soft-thresholding (proximal operator).
        grad = (XtX @ y_live - rhs) / n_eff
        stepped = y_live - step * grad
        new = cp.sign(stepped) * cp.maximum(
            cp.abs(stepped) - shrink[live].reshape(1, -1), 0.0
        )

        # Nesterov momentum sequence.
        t_prev = t_vals[live]
        t_next = (1.0 + cp.sqrt(1.0 + 4.0 * (t_prev ** 2))) / 2.0
        weight = (t_prev - 1.0) / t_next
        extrapolated = new + weight.reshape(1, -1) * (new - prev)

        coefs[:, live] = new
        momentum_pts[:, live] = extrapolated
        t_vals[live] = t_next

        every = _adaptive_gpu_check_every(
            base_check_every=base_every,
            iteration=it,
            max_iter=total_iters,
            active_ratio=float(int(live.size)) / float(max(1, n_alphas)),
        )
        if (it + 1) % every != 0 and (it + 1) != total_iters:
            continue

        if use_kkt:
            resid_grad = (XtX @ new - rhs) / n_eff
            score = cp.max(
                cp.maximum(
                    cp.abs(resid_grad) - alpha_dev[live].reshape(1, -1),
                    0.0,
                ),
                axis=0,
            )
        else:
            score = cp.sum(cp.abs(new - prev), axis=0)
        hit = score < tol_value

        finished = live[hit]
        if int(finished.size) == 0:
            continue

        iters_dev[finished] = it + 1
        # Freeze the momentum point of finished columns at their solution.
        momentum_pts[:, finished] = coefs[:, finished]
        live = live[~hit]

    if int(live.size) > 0:
        iters_dev[live] = total_iters

    return coefs.T, cp.asnumpy(iters_dev)
840
+
841
+
842
def _solve_lasso_path_gpu_fista_multi_fold_from_gram(
    XtX_batch,
    Xty_batch,
    *,
    n_samples_vec,
    alphas_desc,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve descending-alpha Lasso paths for all folds together on GPU.

    Note: Fused kernel optimization is disabled for multi-fold solver due to
    dtype complexity. The single-fold Lasso solver uses fused kernels.

    Parameters
    ----------
    XtX_batch, Xty_batch
        Per-fold Gram matrices ``(n_folds, p, p)`` and cross-products that
        reshape to ``(n_folds, p, 1)`` (CuPy arrays).
    n_samples_vec
        One sample count per fold (CuPy or array-like). Counts below 1 are
        treated as 1, matching ``max(1, n_samples)`` in the single-fold solver.
    alphas_desc
        Penalty values shared by every fold (CuPy or array-like).
    max_iter, tol
        Iteration cap and convergence tolerance.
    stopping
        ``"kkt"`` (case-insensitive) selects the KKT-violation test; any other
        value selects the L1 coefficient-change test.
    lipschitz_L
        Optional Lipschitz constant applied to every fold. When omitted it is
        derived per fold from the largest eigenvalue of the Gram matrix, with
        a row-sum bound as fallback if the eigen-solver raises.
    check_every
        Base interval between convergence checks.

    Returns
    -------
    tuple
        ``(coefs, n_iters)``: a CuPy array ``(n_folds, n_alphas, n_features)``
        and a NumPy ``int32`` array ``(n_folds, n_alphas)``. Folds whose
        Lipschitz constant is not positive keep zero coefficients and zero
        iteration counts, like the single-fold solver's early return.

    Raises
    ------
    ValueError
        If ``n_samples_vec`` does not hold exactly one entry per fold.
    """
    import cupy as cp

    dtype = XtX_batch.dtype
    n_folds = int(XtX_batch.shape[0])
    n_features = int(XtX_batch.shape[1])
    n_alphas = int(alphas_desc.shape[0])
    n_total_iter = int(max_iter)
    tol_value = float(tol)

    coefs = cp.zeros((n_folds, n_features, n_alphas), dtype=dtype)
    yk = coefs.copy()
    tk = cp.ones((n_folds, n_alphas), dtype=dtype)
    n_iters_gpu = cp.zeros((n_folds, n_alphas), dtype=cp.int32)

    # Convert n_samples_vec to numpy using .get() if it's a CuPy array
    if hasattr(n_samples_vec, 'get'):
        n_vec_cpu = n_samples_vec.get().astype(np.float64).reshape(-1)
    else:
        n_vec_cpu = np.asarray(n_samples_vec, dtype=np.float64).reshape(-1)
    if n_vec_cpu.size != n_folds:
        raise ValueError("n_samples_vec must have one entry per fold")
    # Floor at 1 so an empty fold cannot cause a division by zero.
    n_vec_cpu = np.maximum(n_vec_cpu, 1.0)
    n_vec = cp.asarray(n_vec_cpu, dtype=dtype)

    if lipschitz_L is not None:
        L = cp.full((n_folds,), float(lipschitz_L), dtype=dtype)
    else:
        try:
            eigvals = cp.linalg.eigvalsh(XtX_batch)
            L = eigvals[:, -1] / n_vec
        except Exception:
            row_sum_bound = cp.max(cp.sum(cp.abs(XtX_batch), axis=2), axis=1) / n_vec
            L = cp.maximum(row_sum_bound, cp.asarray(1e-12, dtype=dtype))

    # A fold with a non-positive constant has no usable step size (1/L would
    # be inf or negative and produce NaN coefficients). Park such folds as
    # inactive from the start; a placeholder L of 1 keeps the arithmetic finite.
    degenerate = L <= 0
    L = cp.where(degenerate, cp.ones_like(L), L)

    step = 1.0 / L.reshape(n_folds, 1, 1)
    # Convert alphas_desc to numpy using .get() if it's a CuPy array
    if hasattr(alphas_desc, 'get'):
        alphas_cpu = alphas_desc.get().astype(np.float64)
    else:
        alphas_cpu = np.asarray(alphas_desc, dtype=np.float64)
    alpha_gpu = cp.asarray(alphas_cpu, dtype=dtype).reshape(1, 1, n_alphas)
    thresholds = alpha_gpu * step

    Xty_expanded = Xty_batch.reshape(n_folds, n_features, 1)
    n_vec_expanded = n_vec.reshape(n_folds, 1, 1)
    use_kkt = str(stopping).lower() == "kkt"
    check_every = max(1, int(check_every))

    active_gpu = cp.logical_and(
        ~degenerate[:, cp.newaxis],
        cp.ones((n_folds, n_alphas), dtype=cp.bool_),
    )
    active_count = int(cp.count_nonzero(active_gpu).item())

    for iteration in range(n_total_iter):
        if active_count == 0:
            break

        active_expanded = active_gpu[:, cp.newaxis, :]

        # ``coefs`` is only ever rebound (cp.where allocates a new array), so
        # keeping a reference to the previous iterate is enough; no copy needed.
        coef_old = coefs
        grad = (cp.matmul(XtX_batch, yk) - Xty_expanded) / n_vec_expanded

        # Proximal step: soft thresholding
        yk_step = yk - step * grad
        coef_candidate = cp.sign(yk_step) * cp.maximum(cp.abs(yk_step) - thresholds, 0.0)
        coefs = cp.where(active_expanded, coef_candidate, coefs)

        # Nesterov momentum; finished/inactive entries keep their old state.
        t_old = tk
        t_new = (1.0 + cp.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
        beta = (t_old - 1.0) / t_new
        y_candidate = coefs + beta[:, cp.newaxis, :] * (coefs - coef_old)
        yk = cp.where(active_expanded, y_candidate, yk)
        tk = cp.where(active_gpu, t_new, tk)

        active_ratio = float(active_count) / float(max(1, n_folds * n_alphas))
        check_every_eff = _adaptive_gpu_check_every(
            base_check_every=check_every,
            iteration=iteration,
            max_iter=n_total_iter,
            active_ratio=active_ratio,
        )
        should_check = ((iteration + 1) % check_every_eff == 0) or (iteration + 1 == n_total_iter)
        if not should_check:
            continue

        if use_kkt:
            grad_sse = (cp.matmul(XtX_batch, coefs) - Xty_expanded) / n_vec_expanded
            violation = cp.max(cp.maximum(cp.abs(grad_sse) - alpha_gpu, 0.0), axis=1)
            converged_local_gpu = violation < tol_value
        else:
            delta = cp.sum(cp.abs(coefs - coef_old), axis=1)
            converged_local_gpu = delta < tol_value

        newly_done_gpu = active_gpu & converged_local_gpu
        done_count = int(cp.count_nonzero(newly_done_gpu).item())
        if done_count == 0:
            continue

        n_iters_gpu[newly_done_gpu] = int(iteration) + 1
        # Freeze the momentum point of finished entries at their solution.
        yk = cp.where(newly_done_gpu[:, cp.newaxis, :], coefs, yk)
        active_gpu = active_gpu & (~converged_local_gpu)
        active_count -= done_count

    # Entries still active here exhausted the iteration budget.
    n_iters_gpu[active_gpu] = n_total_iter

    return cp.transpose(coefs, (0, 2, 1)), cp.asnumpy(n_iters_gpu)
964
+
965
+
966
def _solve_lasso_path_gpu_fista_multi_fold_from_gram_torch(
    XtX_batch,
    Xty_batch,
    *,
    n_samples_vec,
    alphas_desc,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """Solve descending-alpha Lasso paths for all folds together on Torch GPU.

    Mirror of _solve_lasso_path_gpu_fista_multi_fold_from_gram for Torch backend.

    Parameters
    ----------
    XtX_batch, Xty_batch
        Per-fold Gram matrices ``(n_folds, p, p)`` and cross-products that
        reshape to ``(n_folds, p, 1)`` (Torch tensors). All work tensors are
        created with the dtype and device of ``XtX_batch``.
    n_samples_vec
        One sample count per fold. Counts below 1 are treated as 1, matching
        ``max(1, n_samples)`` in the single-fold solver.
    alphas_desc
        Penalty values shared by every fold.
    max_iter, tol
        Iteration cap and convergence tolerance.
    stopping
        ``"kkt"`` (case-insensitive) selects the KKT-violation test; any other
        value selects the L1 coefficient-change test.
    lipschitz_L
        Optional Lipschitz constant applied to every fold. When omitted it is
        derived per fold from the largest eigenvalue of the Gram matrix, with
        a row-sum bound as fallback if the eigen-solver raises.
    check_every
        Base interval between convergence checks.

    Returns
    -------
    tuple
        ``(coefs, n_iters)``: a Torch tensor ``(n_folds, n_alphas, n_features)``
        and a NumPy ``int32`` array ``(n_folds, n_alphas)``. Folds whose
        Lipschitz constant is not positive keep zero coefficients and zero
        iteration counts, like the single-fold solver's early return.

    Raises
    ------
    ValueError
        If ``n_samples_vec`` does not hold exactly one entry per fold.
    """
    import torch

    dtype = XtX_batch.dtype
    dev = XtX_batch.device
    n_folds = int(XtX_batch.shape[0])
    n_features = int(XtX_batch.shape[1])
    n_alphas = int(alphas_desc.shape[0])
    n_total_iter = int(max_iter)
    tol_value = float(tol)

    coefs = torch.zeros((n_folds, n_features, n_alphas), dtype=dtype, device=dev)
    yk = coefs.clone()
    tk = torch.ones((n_folds, n_alphas), dtype=dtype, device=dev)
    n_iters_gpu = torch.zeros((n_folds, n_alphas), dtype=torch.int32, device=dev)

    n_vec_cpu = np.asarray(_to_numpy(n_samples_vec), dtype=np.float64).reshape(-1)
    if n_vec_cpu.size != n_folds:
        raise ValueError("n_samples_vec must have one entry per fold")
    # Floor at 1 so an empty fold cannot cause a division by zero.
    n_vec_cpu = np.maximum(n_vec_cpu, 1.0)
    n_vec = torch.from_numpy(n_vec_cpu).to(dtype=dtype, device=dev)

    if lipschitz_L is not None:
        L = torch.full((n_folds,), float(lipschitz_L), dtype=dtype, device=dev)
    else:
        try:
            eigvals = torch.linalg.eigvalsh(XtX_batch)
            L = eigvals[:, -1] / n_vec
        except Exception:
            row_sum_bound = torch.max(torch.sum(torch.abs(XtX_batch), dim=2), dim=1).values / n_vec
            L = torch.maximum(row_sum_bound, torch.tensor(1e-12, dtype=dtype, device=dev))

    # A fold with a non-positive constant has no usable step size (1/L would
    # be inf or negative and produce NaN coefficients). Park such folds as
    # inactive from the start; a placeholder L of 1 keeps the arithmetic finite.
    degenerate = L <= 0
    L = torch.where(degenerate, torch.ones_like(L), L)

    step = 1.0 / L.reshape(n_folds, 1, 1)
    alphas_cpu = np.asarray(_to_numpy(alphas_desc), dtype=np.float64)
    alpha_gpu = torch.from_numpy(alphas_cpu).to(dtype=dtype, device=dev).reshape(1, 1, n_alphas)
    thresholds = alpha_gpu * step

    Xty_expanded = Xty_batch.reshape(n_folds, n_features, 1)
    n_vec_expanded = n_vec.reshape(n_folds, 1, 1)
    use_kkt = str(stopping).lower() == "kkt"
    check_every = max(1, int(check_every))

    active_gpu = (~degenerate).unsqueeze(1).expand(n_folds, n_alphas).clone()
    active_count = int(torch.count_nonzero(active_gpu).item())

    for iteration in range(n_total_iter):
        if active_count == 0:
            break

        active_expanded = active_gpu.unsqueeze(1)

        # ``coefs`` is only ever rebound (torch.where allocates a new tensor),
        # so a reference to the previous iterate is enough; no clone needed.
        coef_old = coefs
        grad = (torch.matmul(XtX_batch, yk) - Xty_expanded) / n_vec_expanded

        # Proximal step: soft thresholding. clamp(min=0) keeps the working
        # dtype and avoids allocating a zero tensor on every iteration.
        yk_step = yk - step * grad
        coef_candidate = torch.sign(yk_step) * torch.clamp(torch.abs(yk_step) - thresholds, min=0.0)
        coefs = torch.where(active_expanded, coef_candidate, coefs)

        # Nesterov momentum; finished/inactive entries keep their old state.
        t_old = tk
        t_new = (1.0 + torch.sqrt(1.0 + 4.0 * (t_old ** 2))) / 2.0
        beta = (t_old - 1.0) / t_new
        y_candidate = coefs + beta.unsqueeze(1) * (coefs - coef_old)
        yk = torch.where(active_expanded, y_candidate, yk)
        tk = torch.where(active_gpu, t_new, tk)

        active_ratio = float(active_count) / float(max(1, n_folds * n_alphas))
        check_every_eff = _adaptive_gpu_check_every(
            base_check_every=check_every,
            iteration=iteration,
            max_iter=n_total_iter,
            active_ratio=active_ratio,
        )
        should_check = ((iteration + 1) % check_every_eff == 0) or (iteration + 1 == n_total_iter)
        if not should_check:
            continue

        if use_kkt:
            grad_sse = (torch.matmul(XtX_batch, coefs) - Xty_expanded) / n_vec_expanded
            violation = torch.max(torch.clamp(torch.abs(grad_sse) - alpha_gpu, min=0.0), dim=1).values
            converged_local_gpu = violation < tol_value
        else:
            delta = torch.sum(torch.abs(coefs - coef_old), dim=1)
            converged_local_gpu = delta < tol_value

        newly_done_gpu = active_gpu & converged_local_gpu
        done_count = int(torch.count_nonzero(newly_done_gpu).item())
        if done_count == 0:
            continue

        n_iters_gpu[newly_done_gpu] = int(iteration) + 1
        # Freeze the momentum point of finished entries at their solution.
        yk = torch.where(newly_done_gpu.unsqueeze(1), coefs, yk)
        active_gpu = active_gpu & (~converged_local_gpu)
        active_count -= done_count

    # Entries still active here exhausted the iteration budget.
    n_iters_gpu[active_gpu] = n_total_iter

    return coefs.permute(0, 2, 1), n_iters_gpu.cpu().numpy()
1074
+
1075
+
1076
def _solve_lasso_path_cpu_from_gram(
    XtX: np.ndarray,
    Xty: np.ndarray,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    cpu_solver: str,
    lipschitz_L: Optional[float] = None,
    cd_kkt_check_every: int = 1,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute a Lasso path on CPU from a single precomputed Gram matrix.

    Dispatch order:

    1. ``cpu_solver == "fista"`` delegates to the batched CPU FISTA solver.
    2. ``cpu_solver == "coordinate_descent"`` tries the Numba kernel when it
       is available and not disabled; any exception from it disables the
       kernel for later calls and falls through.
    3. Otherwise a NumPy cyclic coordinate descent runs, warm-started from
       one alpha to the next and restricted to an active set that grows via
       a screening rule and via KKT scans.

    Returns
    -------
    tuple
        ``(coefs_path, n_iters)`` with shapes ``(n_alphas, n_features)`` and
        ``(n_alphas,)``.
    """
    global _NUMBA_CD_DISABLED

    n_features = int(XtX.shape[0])
    n_alphas = int(alphas_desc.shape[0])
    solver_name = str(cpu_solver).lower()

    if solver_name == "fista":
        return _solve_lasso_path_cpu_fista_batched_from_gram(
            XtX,
            Xty,
            n_samples=n_samples,
            alphas_desc=alphas_desc,
            max_iter=max_iter,
            tol=tol,
            stopping=stopping,
            lipschitz_L=lipschitz_L,
            check_every=2,
        )

    if _NUMBA_AVAILABLE and not _NUMBA_CD_DISABLED and solver_name == "coordinate_descent":
        try:
            return _solve_lasso_path_cpu_cd_numba(
                XtX,
                Xty,
                n_samples=n_samples,
                alphas_desc=alphas_desc,
                max_iter=max_iter,
                tol=tol,
                stopping=stopping,
                cd_kkt_check_every=cd_kkt_check_every,
            )
        except Exception:
            # Best-effort acceleration: switch it off and use NumPy below.
            _NUMBA_CD_DISABLED = True

    # ---- NumPy coordinate descent with incremental gradient updates ----
    use_kkt = str(stopping).lower() == "kkt"
    tol_value = float(tol)
    sweeps = int(max_iter)
    n_eff = float(max(1, n_samples))
    scan_every = max(1, int(cd_kkt_check_every))

    coefs_path = np.zeros((n_alphas, n_features), dtype=np.float64)
    n_iters = np.zeros(n_alphas, dtype=np.int32)
    coef = np.zeros(n_features, dtype=np.float64)

    diag = np.diag(XtX).astype(np.float64, copy=False)
    grad = XtX @ coef - Xty
    scaled_alphas = np.asarray(alphas_desc, dtype=np.float64) * n_eff
    active_mask = np.zeros((n_features,), dtype=bool)

    for alpha_idx, alpha in enumerate(alphas_desc):
        alpha_scaled = float(scaled_alphas[alpha_idx])
        if alpha_idx > 0:
            prev_alpha_scaled = float(scaled_alphas[alpha_idx - 1])
        else:
            prev_alpha_scaled = alpha_scaled

        # Strong rule screening: expand active set before cyclic updates.
        screen_level = max(0.0, 2.0 * alpha_scaled - prev_alpha_scaled)
        active_mask |= np.abs(Xty) >= screen_level
        if not bool(np.any(active_mask)):
            active_mask[int(np.argmax(np.abs(Xty)))] = True

        for sweep in range(sweeps):
            moved = 0.0

            for j in np.flatnonzero(active_mask):
                diag_j = float(diag[j])
                current = float(coef[j])

                if diag_j > 1e-10:
                    rho_j = -float(grad[j]) + diag_j * current
                    updated = _soft_threshold_scalar(rho_j, alpha_scaled) / diag_j
                else:
                    updated = 0.0

                change = updated - current
                if abs(change) > 0.0:
                    coef[j] = updated
                    # Keep grad == XtX @ coef - Xty without a full matvec.
                    grad += XtX[:, j] * change
                    moved += abs(change)

            # glmnet-style optimization can skip full inactive KKT scans on
            # every pass, then force a check when updates become small.
            scan_now = (
                ((sweep + 1) % scan_every == 0)
                or (moved < tol_value)
                or (sweep + 1 == sweeps)
            )
            worst = float("inf")
            new_violators = np.empty((0,), dtype=np.int64)

            if scan_now:
                excess = np.maximum(np.abs(grad / n_eff) - float(alpha), 0.0)
                new_violators = np.where((excess > tol_value) & (~active_mask))[0]
                if new_violators.size > 0:
                    active_mask[new_violators] = True
                worst = float(np.max(excess))

            if use_kkt:
                finished = scan_now and worst < tol_value
            else:
                finished = moved < tol_value and new_violators.size == 0

            if finished:
                n_iters[alpha_idx] = sweep + 1
                break
        else:
            # Loop ended without a break: the sweep budget was exhausted.
            n_iters[alpha_idx] = sweeps

        coefs_path[alpha_idx, :] = coef
        active_mask |= np.abs(coef) > 0.0

    return coefs_path, n_iters
1212
+
1213
+
1214
def _solve_lasso_path_gpu_from_gram(
    XtX,
    Xty,
    *,
    n_samples: int,
    alphas_desc: np.ndarray,
    max_iter: int,
    tol: float,
    stopping: str,
    lipschitz_L: Optional[float] = None,
    check_every: int = 8,
):
    """GPU entry point for a Gram-based descending-alpha Lasso path.

    Forwards every argument unchanged to
    ``_solve_lasso_path_gpu_fista_batched_from_gram`` and returns its result.
    """
    solver_options = {
        "n_samples": n_samples,
        "alphas_desc": alphas_desc,
        "max_iter": max_iter,
        "tol": tol,
        "stopping": stopping,
        "lipschitz_L": lipschitz_L,
        "check_every": check_every,
    }
    return _solve_lasso_path_gpu_fista_batched_from_gram(XtX, Xty, **solver_options)
1238
+
1239
+
1240
def _soft_threshold_torch(x, gamma):
    """Apply the soft-thresholding operator to a Torch tensor.

    Computes ``sign(x) * max(|x| - gamma, 0)`` element-wise.
    """
    import torch

    shrunk = torch.abs(x) - gamma
    # Zero-dim zero with the dtype and device of ``x`` to floor against.
    floor = x.new_zeros(())
    return torch.sign(x) * torch.maximum(shrunk, floor)
1244
+
1245
+
1246
def _fit_lasso_single_alpha_fast(
    X,
    y,
    *,
    alpha: float,
    fit_intercept: bool,
    max_iter: int,
    tol: float,
    stopping: str,
    device: str,
    cpu_solver: str,
    cd_kkt_check_every: int = 1,
    sample_weight=None,
) -> Dict[str, object]:
    """Fast single-alpha Lasso fit using optimized Gram-based path solvers.

    Three execution paths are selected from ``device`` and the type of ``X``:

    * CuPy: ``device`` equals ``Device.CUDA.value`` and ``X`` is not a Torch
      tensor; solved with ``_solve_lasso_path_gpu_from_gram``.
    * Torch: ``device`` equals ``Device.CUDA.value`` and ``X`` is a Torch
      tensor; solved with an inline FISTA loop on the tensor's device.
    * CPU: everything else; solved with ``_solve_lasso_path_cpu_from_gram``.

    Sample weights are applied by scaling each row of ``X`` and ``y`` by
    ``sqrt(weight)``. With ``fit_intercept`` the data is centered on the
    weighted means, which are computed from the unscaled inputs so that
    zero-weight rows do not produce 0/0 = NaN.

    Parameters
    ----------
    X, y
        Design matrix (2-D) and response (flattened to 1-D).
    alpha
        Lasso penalty.
    fit_intercept
        Whether to center the data and report an intercept.
    max_iter, tol
        Iteration cap and convergence tolerance for the solver.
    stopping
        ``"kkt"`` (case-insensitive) or any other value for the L1
        coefficient-change test.
    device
        Device name, compared case-insensitively with ``Device.CUDA.value``.
    cpu_solver, cd_kkt_check_every
        Forwarded to the CPU path solver.
    sample_weight
        Optional 1-D per-row weights.

    Returns
    -------
    dict
        Keys ``"coef"`` (NumPy float64 vector), ``"intercept"`` (float),
        ``"n_iter"`` (int), ``"n_samples"`` (int) and ``"n_features"`` (int).
    """
    device_name = str(device).lower()
    alpha_vec = np.asarray([float(alpha)], dtype=np.float64)

    # Check if inputs are torch tensors on GPU. Torch is an optional
    # dependency: if it cannot be imported, the input cannot be a tensor.
    is_torch_gpu = False
    try:
        import torch
        is_torch_gpu = device_name == Device.CUDA.value and isinstance(X, torch.Tensor)
    except Exception:
        pass

    if device_name == Device.CUDA.value and not is_torch_gpu:
        # CuPy GPU path
        import cupy as cp

        X_raw = cp.asarray(X)
        y_raw = cp.asarray(y).reshape(-1)
        sw = None
        X_arr = X_raw
        y_arr = y_raw

        if sample_weight is not None:
            sw = cp.asarray(sample_weight)
            sqrt_sw = cp.sqrt(sw)
            X_arr = X_raw * sqrt_sw[:, cp.newaxis]
            y_arr = y_raw * sqrt_sw

        if bool(fit_intercept):
            if sw is not None:
                # Weighted means from the unscaled data (safe for zero weights).
                w_sum = float(cp.sum(sw))
                X_mean = cp.sum(X_raw * sw[:, cp.newaxis], axis=0) / w_sum
                y_mean = float(cp.sum(y_raw * sw)) / w_sum
                # Center the sqrt-weighted data using the weighted means.
                X_centered = X_arr - sqrt_sw[:, cp.newaxis] * X_mean
                y_centered = y_arr - sqrt_sw * y_mean
            else:
                X_mean = cp.mean(X_arr, axis=0)
                y_mean = cp.mean(y_arr)
                X_centered = X_arr - X_mean
                y_centered = y_arr - y_mean
        else:
            X_mean = cp.zeros((X_arr.shape[1],), dtype=X_arr.dtype)
            y_mean = cp.array(0.0, dtype=X_arr.dtype)
            X_centered = X_arr
            y_centered = y_arr

        XtX = X_centered.T @ X_centered
        Xty = X_centered.T @ y_centered

        coefs_desc, n_iters = _solve_lasso_path_gpu_from_gram(
            XtX,
            Xty,
            n_samples=int(X_arr.shape[0]),
            alphas_desc=alpha_vec,
            max_iter=int(max_iter),
            tol=float(tol),
            stopping=str(stopping),
            lipschitz_L=None,
            check_every=8,
        )

        coef_gpu = coefs_desc[0]
        if bool(fit_intercept):
            intercept_gpu = y_mean - X_mean @ coef_gpu
            intercept = float(cp.asnumpy(intercept_gpu))
        else:
            intercept = 0.0

        coef = np.asarray(cp.asnumpy(coef_gpu), dtype=np.float64)
        return {
            "coef": coef,
            "intercept": float(intercept),
            "n_iter": int(n_iters[0]),
            "n_samples": int(X_arr.shape[0]),
            "n_features": int(X_arr.shape[1]),
        }

    elif is_torch_gpu:
        # Torch GPU path - use FISTA solver directly on GPU tensors
        import torch

        X_raw = X
        y_raw = y.reshape(-1) if isinstance(y, torch.Tensor) else torch.as_tensor(
            y, dtype=X_raw.dtype, device=X_raw.device
        ).reshape(-1)
        sw = None
        X_arr = X_raw
        y_arr = y_raw

        if sample_weight is not None:
            sw = sample_weight if isinstance(sample_weight, torch.Tensor) else torch.as_tensor(
                sample_weight, dtype=X_raw.dtype, device=X_raw.device
            )
            sqrt_sw = torch.sqrt(sw)
            X_arr = X_raw * sqrt_sw[:, None]
            y_arr = y_raw * sqrt_sw

        if bool(fit_intercept):
            if sw is not None:
                # Weighted mean: sum(w*X)/sum(w) on the unscaled data. The
                # mean of sqrt(w)*X is not the weighted mean, and dividing
                # the scaled data back by sqrt(w) breaks on zero weights.
                w_sum = float(sw.sum())
                X_mean = torch.sum(X_raw * sw[:, None], dim=0) / w_sum
                y_mean = float(torch.sum(y_raw * sw)) / w_sum
                # Re-center the sqrt-weighted data using the weighted mean
                X_centered = X_arr - sqrt_sw[:, None] * X_mean
                y_centered = y_arr - sqrt_sw * y_mean
            else:
                X_mean = torch.mean(X_arr, dim=0)
                y_mean = torch.mean(y_arr)
                X_centered = X_arr - X_mean
                y_centered = y_arr - y_mean
        else:
            X_mean = torch.zeros((X_arr.shape[1],), dtype=X_arr.dtype, device=X_arr.device)
            y_mean = torch.tensor(0.0, dtype=X_arr.dtype, device=X_arr.device)
            X_centered = X_arr
            y_centered = y_arr

        n_samples = int(X_arr.shape[0])
        n_features = int(X_arr.shape[1])

        # Precompute Gram matrix and X'y for FISTA gradient
        XtX = X_centered.T @ X_centered
        Xty = X_centered.T @ y_centered

        # Compute Lipschitz constant L = max eigenvalue of XtX / n
        try:
            eigvals = torch.linalg.eigvalsh(XtX)
            L = eigvals[-1] / n_samples
        except Exception:
            # Squared Frobenius norm of X bounds the largest eigenvalue.
            L = torch.sum(X_centered ** 2) / n_samples
        L = torch.clamp(L, min=1e-10)

        step = 1.0 / L
        thresh = float(alpha) * step
        use_kkt = str(stopping).lower() == "kkt"
        alpha_value = float(alpha)
        tol_value = float(tol)

        # FISTA initialization
        coef = torch.zeros(n_features, dtype=X_arr.dtype, device=X_arr.device)
        z = coef.clone()
        t = torch.tensor(1.0, dtype=X_arr.dtype, device=X_arr.device)

        # Tracked explicitly so max_iter == 0 reports 0 iterations instead of
        # failing on an unbound loop variable.
        n_iter = 0

        # FISTA iterations
        for iteration in range(int(max_iter)):
            n_iter = iteration + 1
            coef_old = coef.clone()

            # Gradient step at z
            grad = (XtX @ z - Xty) / n_samples
            coef = _soft_threshold_torch(z - step * grad, thresh)

            # Momentum update
            t_new = (1.0 + torch.sqrt(1.0 + 4.0 * t ** 2)) / 2.0
            z = coef + ((t - 1.0) / t_new) * (coef - coef_old)
            t = t_new

            # Convergence check
            if use_kkt:
                grad_sse = (XtX @ coef - Xty) / n_samples
                violation = torch.max(torch.clamp(torch.abs(grad_sse) - alpha_value, min=0.0))
                if violation < tol_value:
                    break
            else:
                if torch.sum(torch.abs(coef - coef_old)) < tol_value:
                    break

        # Build coefficients
        if bool(fit_intercept):
            intercept_torch = y_mean - X_mean @ coef
            intercept = float(intercept_torch.item())
        else:
            intercept = 0.0

        coef_np = np.asarray(coef.detach().cpu().numpy(), dtype=np.float64)
        return {
            "coef": coef_np,
            "intercept": float(intercept),
            "n_iter": int(n_iter),
            "n_samples": n_samples,
            "n_features": n_features,
        }

    # CPU (NumPy) path
    X_raw = np.asarray(X)
    y_raw = np.asarray(y).reshape(-1)
    sw = None
    X_arr = X_raw
    y_arr = y_raw

    if sample_weight is not None:
        sw = np.asarray(sample_weight)
        sqrt_sw = np.sqrt(sw)
        X_arr = X_raw * sqrt_sw[:, np.newaxis]
        y_arr = y_raw * sqrt_sw

    if bool(fit_intercept):
        if sw is not None:
            # Weighted means from the unscaled data (safe for zero weights).
            w_sum = float(np.sum(sw))
            X_mean = np.sum(X_raw * sw[:, np.newaxis], axis=0) / w_sum
            y_mean = float(np.sum(y_raw * sw)) / w_sum
            # Center the sqrt-weighted data using the weighted mean
            X_centered = X_arr - sqrt_sw[:, np.newaxis] * X_mean
            y_centered = y_arr - sqrt_sw * y_mean
        else:
            X_mean = np.mean(X_arr, axis=0)
            y_mean = float(np.mean(y_arr))
            X_centered = X_arr - X_mean
            y_centered = y_arr - y_mean
    else:
        X_mean = np.zeros((X_arr.shape[1],), dtype=np.float64)
        y_mean = 0.0
        X_centered = X_arr
        y_centered = y_arr

    XtX = X_centered.T @ X_centered
    Xty = X_centered.T @ y_centered

    coefs_desc, n_iters = _solve_lasso_path_cpu_from_gram(
        XtX,
        Xty,
        n_samples=int(X_arr.shape[0]),
        alphas_desc=alpha_vec,
        max_iter=int(max_iter),
        tol=float(tol),
        stopping=str(stopping),
        cpu_solver=str(cpu_solver),
        lipschitz_L=None,
        cd_kkt_check_every=int(cd_kkt_check_every),
    )

    coef = np.asarray(coefs_desc[0], dtype=np.float64)
    if bool(fit_intercept):
        intercept = float(y_mean - X_mean @ coef)
    else:
        intercept = 0.0

    return {
        "coef": coef,
        "intercept": float(intercept),
        "n_iter": int(n_iters[0]),
        "n_samples": int(X_arr.shape[0]),
        "n_features": int(X_arr.shape[1]),
    }
1503
+
1504
+
1505
+ def _select_lasso_alpha_cv(
1506
+ X,
1507
+ y,
1508
+ *,
1509
+ alphas=None,
1510
+ n_alphas: int = 12,
1511
+ alpha_min_ratio: float = 1e-3,
1512
+ cv_folds: int = 5,
1513
+ cv_splits=None,
1514
+ random_state: Optional[int] = None,
1515
+ sample_weight=None,
1516
+ fit_intercept: bool = False,
1517
+ device: Union[str, Device] = Device.CPU,
1518
+ max_iter: int = 3000,
1519
+ tol: float = 1e-4,
1520
+ cpu_solver: str = "coordinate_descent",
1521
+ method: str = "standard",
1522
+ cd_kkt_check_every: Optional[int] = None,
1523
+ gpu_cv_mixed_precision: bool = True,
1524
+ return_details: bool = False,
1525
+ cache_key: Optional[Tuple[Any, ...]] = None,
1526
+ ):
1527
+ """
1528
+ Select alpha via K-fold CV using statgpu's own Lasso implementation.
1529
+
1530
+ Notes
1531
+ -----
1532
+ - Does not depend on sklearn.
1533
+ - Supports GPU path by setting ``device='cuda'``.
1534
+ """
1535
+ device_name = str(device).lower()
1536
+ use_gpu = device_name in (Device.CUDA.value, Device.TORCH.value)
1537
+ gpu_requested = use_gpu
1538
+
1539
+ gpu_input_cupy = False
1540
+ gpu_input_torch = False
1541
+ if use_gpu:
1542
+ # Check if inputs are already on GPU (CuPy or Torch)
1543
+ try:
1544
+ import cupy as cp
1545
+ gpu_input_cupy = isinstance(X, cp.ndarray) and isinstance(y, cp.ndarray)
1546
+ if sample_weight is not None and not isinstance(sample_weight, cp.ndarray):
1547
+ gpu_input_cupy = False
1548
+ except Exception:
1549
+ pass
1550
+
1551
+ # Also check for torch tensors
1552
+ if not gpu_input_cupy:
1553
+ try:
1554
+ import torch
1555
+ gpu_input_torch = isinstance(X, torch.Tensor) and isinstance(y, torch.Tensor)
1556
+ if sample_weight is not None and not isinstance(sample_weight, torch.Tensor):
1557
+ gpu_input_torch = False
1558
+ except Exception:
1559
+ pass
1560
+
1561
+ X_np = None
1562
+ y_np = None
1563
+ sample_weight_np = None
1564
+
1565
+ if gpu_input_cupy or gpu_input_torch:
1566
+ # GPU inputs - get backend for validation
1567
+ backend = get_backend(backend='auto', device='cuda')
1568
+ if len(tuple(X.shape)) != 2:
1569
+ raise ValueError("X must be a 2D array")
1570
+ n_samples = int(X.shape[0])
1571
+ y_check = backend.asarray(y).reshape(-1)
1572
+ if int(y_check.shape[0]) != n_samples:
1573
+ raise ValueError("y must have the same number of rows as X")
1574
+ if sample_weight is not None:
1575
+ sw_check = backend.asarray(sample_weight).reshape(-1)
1576
+ if int(sw_check.shape[0]) != n_samples:
1577
+ raise ValueError("sample_weight must have the same number of rows as X")
1578
+ else:
1579
+ X_np = np.asarray(X, dtype=np.float64)
1580
+ y_np = np.asarray(y, dtype=np.float64).reshape(-1)
1581
+ if sample_weight is not None:
1582
+ sample_weight_np = np.asarray(sample_weight, dtype=np.float64).reshape(-1)
1583
+ if X_np.ndim != 2:
1584
+ raise ValueError("X must be a 2D array")
1585
+ if y_np.shape[0] != X_np.shape[0]:
1586
+ raise ValueError("y must have the same number of rows as X")
1587
+ if sample_weight_np is not None and sample_weight_np.shape[0] != X_np.shape[0]:
1588
+ raise ValueError("sample_weight must have the same number of rows as X")
1589
+ n_samples = int(X_np.shape[0])
1590
+
1591
+ cv_method = _normalize_lassocv_method(method)
1592
+ requested_cd_kkt_check_every = _normalize_cd_kkt_check_every(cd_kkt_check_every)
1593
+
1594
+ if alphas is None:
1595
+ if gpu_input_cupy or gpu_input_torch:
1596
+ # Get backend based on input type
1597
+ if gpu_input_torch:
1598
+ backend = get_backend(backend='torch', device='cuda')
1599
+ else:
1600
+ backend = get_backend(backend='cupy', device='cuda')
1601
+ alpha_grid = _default_lasso_alpha_grid_backend(
1602
+ X,
1603
+ y,
1604
+ backend,
1605
+ n_alphas=n_alphas,
1606
+ alpha_min_ratio=alpha_min_ratio,
1607
+ )
1608
+ else:
1609
+ alpha_grid = _default_lasso_alpha_grid(
1610
+ X_np,
1611
+ y_np,
1612
+ n_alphas=n_alphas,
1613
+ alpha_min_ratio=alpha_min_ratio,
1614
+ )
1615
+ else:
1616
+ alpha_grid = np.asarray(alphas, dtype=np.float64).reshape(-1)
1617
+ alpha_grid = alpha_grid[np.isfinite(alpha_grid)]
1618
+ alpha_grid = alpha_grid[alpha_grid > 0.0]
1619
+ if alpha_grid.size == 0:
1620
+ if gpu_input_cupy or gpu_input_torch:
1621
+ # Get backend based on input type
1622
+ if gpu_input_torch:
1623
+ backend = get_backend(backend='torch', device='cuda')
1624
+ else:
1625
+ backend = get_backend(backend='cupy', device='cuda')
1626
+ alpha_grid = _default_lasso_alpha_grid_backend(
1627
+ X,
1628
+ y,
1629
+ backend,
1630
+ n_alphas=n_alphas,
1631
+ alpha_min_ratio=alpha_min_ratio,
1632
+ )
1633
+ else:
1634
+ alpha_grid = _default_lasso_alpha_grid(
1635
+ X_np,
1636
+ y_np,
1637
+ n_alphas=n_alphas,
1638
+ alpha_min_ratio=alpha_min_ratio,
1639
+ )
1640
+
1641
+ user_folds = _normalize_cv_splits(cv_splits, n_samples=n_samples)
1642
+ effective_n_folds = int(len(user_folds)) if user_folds is not None else int(cv_folds)
1643
+
1644
+ if int(n_samples) < 4 or int(alpha_grid.size) == 1 or int(effective_n_folds) < 2:
1645
+ alpha0 = float(alpha_grid[0])
1646
+ if not return_details:
1647
+ return alpha0
1648
+ return {
1649
+ "alpha": alpha0,
1650
+ "alphas": alpha_grid.astype(np.float64, copy=False),
1651
+ "mse_path": np.full((int(alpha_grid.size), 1), np.nan, dtype=np.float64),
1652
+ "mean_mse": np.full(int(alpha_grid.size), np.nan, dtype=np.float64),
1653
+ }
1654
+
1655
+ if user_folds is not None:
1656
+ folds = user_folds
1657
+ else:
1658
+ folds = _kfold_indices(
1659
+ n_samples=int(n_samples),
1660
+ n_splits=int(cv_folds),
1661
+ random_state=random_state,
1662
+ )
1663
+
1664
+ folds_are_complements = _folds_are_complements(folds, n_samples=int(n_samples))
1665
+
1666
+ alpha_grid = alpha_grid.astype(np.float64, copy=False)
1667
+ n_alpha = int(alpha_grid.size)
1668
+ n_folds = int(len(folds))
1669
+
1670
+ cache_key_eff = cache_key
1671
+ if cache_key_eff is None and _LASSO_CV_ALPHA_CACHE_MAXSIZE > 0:
1672
+ cache_key_eff = _make_lasso_cv_auto_cache_key(
1673
+ X=X,
1674
+ y=y,
1675
+ sample_weight=sample_weight,
1676
+ alpha_grid=alpha_grid,
1677
+ folds=folds,
1678
+ fit_intercept=bool(fit_intercept),
1679
+ use_gpu=bool(use_gpu),
1680
+ max_iter=int(max_iter),
1681
+ tol=float(tol),
1682
+ cpu_solver=str(cpu_solver),
1683
+ cv_method=str(cv_method),
1684
+ cd_kkt_check_every=requested_cd_kkt_check_every,
1685
+ gpu_cv_mixed_precision=bool(gpu_cv_mixed_precision),
1686
+ )
1687
+
1688
+ cached_details = _lasso_cv_cache_get(cache_key_eff)
1689
+ if cached_details is not None:
1690
+ if return_details:
1691
+ return cached_details
1692
+ return float(cached_details["alpha"])
1693
+
1694
+ # Evaluate alpha path in descending order for warm-start efficiency.
1695
+ alpha_order_desc = np.argsort(-alpha_grid)
1696
+ alpha_desc = alpha_grid[alpha_order_desc]
1697
+
1698
+ mse_path = np.full((n_alpha, n_folds), np.nan, dtype=np.float64)
1699
+
1700
+ best_alpha = float(alpha_grid[0])
1701
+ best_mse = float("inf")
1702
+
1703
+ if use_gpu:
1704
+ try:
1705
+ # Get backend based on input type - prefer Torch backend for Torch tensors
1706
+ if gpu_input_torch:
1707
+ backend = get_backend(backend='torch', device='cuda')
1708
+ elif gpu_input_cupy:
1709
+ backend = get_backend(backend='cupy', device='cuda')
1710
+ else:
1711
+ backend = get_backend(backend='auto', device='cuda')
1712
+ xp = backend.xp
1713
+
1714
+ cv_dtype = backend.float32 if bool(gpu_cv_mixed_precision) else backend.float64
1715
+
1716
+ # Convert inputs to backend arrays
1717
+ if gpu_input_cupy or gpu_input_torch:
1718
+ # Already on GPU (CuPy or Torch)
1719
+ X_full = backend.asarray(X, dtype=cv_dtype)
1720
+ y_full = backend.asarray(y, dtype=cv_dtype).reshape(-1)
1721
+ if sample_weight is not None:
1722
+ sw_full = backend.asarray(sample_weight, dtype=cv_dtype).reshape(-1)
1723
+ else:
1724
+ sw_full = None
1725
+ else:
1726
+ # Convert from numpy
1727
+ X_full = backend.asarray(X_np, dtype=cv_dtype)
1728
+ y_full = backend.asarray(y_np, dtype=cv_dtype)
1729
+ if sample_weight_np is not None:
1730
+ sw_full = backend.asarray(sample_weight_np, dtype=cv_dtype)
1731
+ else:
1732
+ sw_full = None
1733
+
1734
+ XtX_folds = []
1735
+ Xty_folds = []
1736
+ n_train_folds = []
1737
+ X_mean_folds = []
1738
+ y_mean_folds = []
1739
+ fold_eval_payload = []
1740
+
1741
+ fast_fold_stats = (sw_full is None) and bool(folds_are_complements)
1742
+ if fast_fold_stats:
1743
+ n_total = int(X_full.shape[0])
1744
+ XtX_full = X_full.T @ X_full
1745
+ Xty_full = X_full.T @ y_full
1746
+ if bool(fit_intercept):
1747
+ X_sum_full = backend.sum(X_full, axis=0)
1748
+ y_sum_full = backend.sum(y_full)
1749
+ else:
1750
+ X_sum_full = None
1751
+ y_sum_full = None
1752
+
1753
+ for fold_idx, (train_idx, val_idx) in enumerate(folds):
1754
+ train_idx_gpu = backend.asarray(train_idx)
1755
+ val_idx_gpu = backend.asarray(val_idx)
1756
+
1757
+ X_val = X_full[val_idx_gpu]
1758
+ y_val = y_full[val_idx_gpu]
1759
+ sw_val = None if sw_full is None else sw_full[val_idx_gpu]
1760
+ sw_train = None # initialized per-fold in slow path; None for fast path
1761
+
1762
+ if fast_fold_stats:
1763
+ n_val = int(val_idx_gpu.shape[0])
1764
+ n_train = int(n_total - n_val)
1765
+
1766
+ XtX_val = X_val.T @ X_val
1767
+ Xty_val = X_val.T @ y_val
1768
+ XtX_raw = XtX_full - XtX_val
1769
+ Xty_raw = Xty_full - Xty_val
1770
+
1771
+ if bool(fit_intercept):
1772
+ X_sum_val = backend.sum(X_val, axis=0)
1773
+ y_sum_val = backend.sum(y_val)
1774
+ X_sum_train = X_sum_full - X_sum_val
1775
+ y_sum_train = y_sum_full - y_sum_val
1776
+
1777
+ inv_n = backend.asarray(1.0 / float(max(1, n_train)), dtype=X_full.dtype)
1778
+ X_mean = X_sum_train * inv_n
1779
+ y_mean = y_sum_train * inv_n
1780
+ XtX = XtX_raw - backend.outer(X_sum_train, X_sum_train) * inv_n
1781
+ Xty = Xty_raw - X_sum_train * y_mean
1782
+ else:
1783
+ X_mean = backend.zeros((X_full.shape[1],), dtype=X_full.dtype)
1784
+ y_mean = backend.array(0.0, dtype=X_full.dtype)
1785
+ XtX = XtX_raw
1786
+ Xty = Xty_raw
1787
+ else:
1788
+ X_train = X_full[train_idx_gpu]
1789
+ y_train = y_full[train_idx_gpu]
1790
+ sw_train = None if sw_full is None else sw_full[train_idx_gpu]
1791
+
1792
+ if sw_train is not None:
1793
+ sqrt_sw = backend.sqrt(sw_train)
1794
+ X_train = X_train * sqrt_sw[:, None]
1795
+ y_train = y_train * sqrt_sw
1796
+
1797
+ if bool(fit_intercept):
1798
+ X_mean = backend.mean(X_train, axis=0)
1799
+ y_mean = backend.mean(y_train)
1800
+ X_centered = X_train - X_mean
1801
+ y_centered = y_train - y_mean
1802
+ else:
1803
+ X_mean = backend.zeros((X_train.shape[1],), dtype=X_train.dtype)
1804
+ y_mean = backend.array(0.0, dtype=X_train.dtype)
1805
+ X_centered = X_train
1806
+ y_centered = y_train
1807
+
1808
+ XtX = X_centered.T @ X_centered
1809
+ Xty = X_centered.T @ y_centered
1810
+ # For weighted case, effective sample size is sum(weights)
1811
+ if sw_train is not None:
1812
+ n_train = float(backend.sum(sw_train))
1813
+ else:
1814
+ n_train = int(X_train.shape[0])
1815
+
1816
+ XtX_folds.append(XtX)
1817
+ Xty_folds.append(Xty)
1818
+ n_train_folds.append(float(n_train) if sw_train is not None else int(n_train))
1819
+ X_mean_folds.append(X_mean)
1820
+ y_mean_folds.append(y_mean)
1821
+ fold_eval_payload.append((X_val, y_val, sw_val))
1822
+
1823
+ XtX_batch = backend.stack(XtX_folds, axis=0)
1824
+ Xty_batch = backend.stack(Xty_folds, axis=0)
1825
+
1826
+ # Use native Torch FISTA solver for Torch backend
1827
+ if hasattr(xp, '__name__') and 'torch' in xp.__name__.lower():
1828
+ import torch
1829
+ n_samples_vec_torch = torch.tensor(np.asarray(n_train_folds, dtype=np.int32), device=XtX_batch.device, dtype=XtX_batch.dtype)
1830
+
1831
+ coefs_batch_desc, _ = _solve_lasso_path_gpu_fista_multi_fold_from_gram_torch(
1832
+ XtX_batch,
1833
+ Xty_batch,
1834
+ n_samples_vec=n_samples_vec_torch,
1835
+ alphas_desc=alpha_desc,
1836
+ max_iter=int(max_iter),
1837
+ tol=float(tol),
1838
+ stopping="coef_delta",
1839
+ lipschitz_L=None,
1840
+ check_every=8,
1841
+ )
1842
+
1843
+ # Convert results back to numpy for evaluation
1844
+ for fold_idx in range(int(len(folds))):
1845
+ coefs_desc_np = coefs_batch_desc[fold_idx] # already numpy from the solver
1846
+
1847
+ if bool(fit_intercept):
1848
+ y_mean_val = float(y_mean_folds[fold_idx])
1849
+ X_mean_val = X_mean_folds[fold_idx]
1850
+ intercepts_desc = y_mean_val - X_mean_val @ coefs_desc_np.T
1851
+ intercepts_desc_gpu = backend.asarray(intercepts_desc)
1852
+ coefs_desc_gpu = backend.asarray(coefs_desc_np)
1853
+ else:
1854
+ intercepts_desc_gpu = backend.zeros((coefs_desc_np.shape[0],), dtype=coefs_desc_np.dtype)
1855
+ coefs_desc_gpu = backend.asarray(coefs_desc_np)
1856
+
1857
+ X_val, y_val, sw_val = fold_eval_payload[fold_idx]
1858
+ mse_desc = _batch_mse_cv(X_val, y_val, coefs_desc_gpu, intercepts_desc_gpu, sample_weight=sw_val)
1859
+
1860
+ mse_path[alpha_order_desc, fold_idx] = mse_desc
1861
+ else:
1862
+ # CuPy backend - use existing solver directly
1863
+ import cupy as cp
1864
+ n_samples_vec_cp = cp.asarray(np.asarray(n_train_folds, dtype=np.int32))
1865
+
1866
+ coefs_batch_desc, _ = _solve_lasso_path_gpu_fista_multi_fold_from_gram(
1867
+ XtX_batch,
1868
+ Xty_batch,
1869
+ n_samples_vec=n_samples_vec_cp,
1870
+ alphas_desc=alpha_desc,
1871
+ max_iter=int(max_iter),
1872
+ tol=float(tol),
1873
+ stopping="coef_delta",
1874
+ lipschitz_L=None,
1875
+ check_every=8,
1876
+ )
1877
+
1878
+ for fold_idx in range(int(len(folds))):
1879
+ coefs_desc = coefs_batch_desc[fold_idx]
1880
+
1881
+ if bool(fit_intercept):
1882
+ intercepts_desc = y_mean_folds[fold_idx] - X_mean_folds[fold_idx] @ coefs_desc.T
1883
+ else:
1884
+ intercepts_desc = backend.zeros((coefs_desc.shape[0],), dtype=coefs_desc.dtype)
1885
+
1886
+ X_val, y_val, sw_val = fold_eval_payload[fold_idx]
1887
+ mse_desc = _batch_mse_cv(X_val, y_val, coefs_desc, intercepts_desc, sample_weight=sw_val)
1888
+
1889
+ mse_path[alpha_order_desc, fold_idx] = mse_desc
1890
+
1891
+ except Exception as exc:
1892
+ raise RuntimeError(
1893
+ "GPU path failed in _select_lasso_alpha_cv with device='cuda'; "
1894
+ "CPU fallback is disabled for strict CUDA execution."
1895
+ ) from exc
1896
+
1897
+ if not use_gpu:
1898
+ if gpu_requested:
1899
+ raise RuntimeError(
1900
+ "device='cuda' requested but GPU path was not executed; "
1901
+ "CPU fallback is disabled for strict CUDA execution."
1902
+ )
1903
+ cpu_solver_name = str(cpu_solver).lower()
1904
+
1905
+ if cv_method == "glmnet":
1906
+ # glmnet-like CV profile: coordinate-descent path with periodic full KKT scans.
1907
+ cpu_solver_name = "coordinate_descent"
1908
+
1909
+ if requested_cd_kkt_check_every is None:
1910
+ cd_kkt_check_every_effective = 4 if cv_method == "glmnet" else 1
1911
+ else:
1912
+ cd_kkt_check_every_effective = int(requested_cd_kkt_check_every)
1913
+
1914
+ fast_fold_stats = (sample_weight_np is None) and bool(folds_are_complements)
1915
+ if fast_fold_stats:
1916
+ n_total = int(X_np.shape[0])
1917
+ XtX_full = X_np.T @ X_np
1918
+ Xty_full = X_np.T @ y_np
1919
+ if bool(fit_intercept):
1920
+ X_sum_full = np.sum(X_np, axis=0)
1921
+ y_sum_full = float(np.sum(y_np))
1922
+ else:
1923
+ X_sum_full = None
1924
+ y_sum_full = None
1925
+
1926
+ for fold_idx, (train_idx, val_idx) in enumerate(folds):
1927
+ X_val = X_np[val_idx]
1928
+ y_val = y_np[val_idx]
1929
+ sw_val = None if sample_weight_np is None else sample_weight_np[val_idx]
1930
+
1931
+ if fast_fold_stats:
1932
+ n_val = int(np.asarray(val_idx, dtype=np.int64).reshape(-1).size)
1933
+ n_train = int(n_total - n_val)
1934
+
1935
+ XtX_val = X_val.T @ X_val
1936
+ Xty_val = X_val.T @ y_val
1937
+ XtX_raw = XtX_full - XtX_val
1938
+ Xty_raw = Xty_full - Xty_val
1939
+
1940
+ if bool(fit_intercept):
1941
+ X_sum_val = np.sum(X_val, axis=0)
1942
+ y_sum_val = float(np.sum(y_val))
1943
+ X_sum_train = X_sum_full - X_sum_val
1944
+ y_sum_train = y_sum_full - y_sum_val
1945
+
1946
+ inv_n = 1.0 / float(max(1, n_train))
1947
+ X_mean = X_sum_train * inv_n
1948
+ y_mean = y_sum_train * inv_n
1949
+ XtX = XtX_raw - np.outer(X_sum_train, X_sum_train) * inv_n
1950
+ Xty = Xty_raw - X_sum_train * y_mean
1951
+ else:
1952
+ X_mean = np.zeros((X_np.shape[1],), dtype=np.float64)
1953
+ y_mean = 0.0
1954
+ XtX = XtX_raw
1955
+ Xty = Xty_raw
1956
+ else:
1957
+ X_train = X_np[train_idx]
1958
+ y_train = y_np[train_idx]
1959
+ sw_train = None if sample_weight_np is None else sample_weight_np[train_idx]
1960
+
1961
+ if bool(fit_intercept):
1962
+ # Compute weighted means on ORIGINAL data (before sqrt-weighting)
1963
+ if sw_train is not None:
1964
+ sw_sum = float(np.sum(sw_train))
1965
+ X_mean = np.sum(X_train * sw_train[:, np.newaxis], axis=0) / sw_sum
1966
+ y_mean = float(np.sum(y_train * sw_train)) / sw_sum
1967
+ else:
1968
+ X_mean = np.mean(X_train, axis=0)
1969
+ y_mean = float(np.mean(y_train))
1970
+ X_centered = X_train - X_mean
1971
+ y_centered = y_train - y_mean
1972
+ else:
1973
+ X_mean = np.zeros((X_train.shape[1],), dtype=np.float64)
1974
+ y_mean = 0.0
1975
+ X_centered = X_train
1976
+ y_centered = y_train
1977
+
1978
+ # Apply sqrt-weighting after centering
1979
+ if sw_train is not None:
1980
+ sqrt_sw = np.sqrt(sw_train)
1981
+ X_centered = X_centered * sqrt_sw[:, np.newaxis]
1982
+ y_centered = y_centered * sqrt_sw
1983
+
1984
+ XtX = X_centered.T @ X_centered
1985
+ Xty = X_centered.T @ y_centered
1986
+ # Use weight sum as effective sample size for proper alpha scaling
1987
+ n_train = float(np.sum(sw_train)) if sw_train is not None else int(X_train.shape[0])
1988
+
1989
+ coefs_desc, _ = _solve_lasso_path_cpu_from_gram(
1990
+ XtX,
1991
+ Xty,
1992
+ n_samples=int(n_train),
1993
+ alphas_desc=alpha_desc,
1994
+ max_iter=int(max_iter),
1995
+ tol=float(tol),
1996
+ stopping="coef_delta",
1997
+ cpu_solver=cpu_solver_name,
1998
+ lipschitz_L=None,
1999
+ cd_kkt_check_every=cd_kkt_check_every_effective,
2000
+ )
2001
+
2002
+ if bool(fit_intercept):
2003
+ intercepts_desc = y_mean - X_mean @ coefs_desc.T
2004
+ else:
2005
+ intercepts_desc = np.zeros((coefs_desc.shape[0],), dtype=np.float64)
2006
+
2007
+ mse_desc = _batch_mse_cv(
2008
+ X_val,
2009
+ y_val,
2010
+ coefs_desc,
2011
+ intercepts_desc,
2012
+ sample_weight=sw_val,
2013
+ )
2014
+
2015
+ mse_path[alpha_order_desc, fold_idx] = np.asarray(mse_desc, dtype=np.float64)
2016
+
2017
+ for alpha_idx, alpha in enumerate(alpha_grid):
2018
+ alpha_f = float(alpha)
2019
+ valid = np.isfinite(mse_path[alpha_idx])
2020
+ if not bool(np.any(valid)):
2021
+ continue
2022
+
2023
+ mean_mse = float(np.mean(mse_path[alpha_idx, valid]))
2024
+ if mean_mse < best_mse:
2025
+ best_mse = mean_mse
2026
+ best_alpha = alpha_f
2027
+
2028
+ mean_mse_vec = np.full(int(alpha_grid.size), np.nan, dtype=np.float64)
2029
+ for alpha_idx in range(int(alpha_grid.size)):
2030
+ valid = np.isfinite(mse_path[alpha_idx])
2031
+ if bool(np.any(valid)):
2032
+ mean_mse_vec[alpha_idx] = float(np.mean(mse_path[alpha_idx, valid]))
2033
+
2034
+ details = {
2035
+ "alpha": float(best_alpha),
2036
+ "alphas": alpha_grid.astype(np.float64, copy=False),
2037
+ "mse_path": mse_path,
2038
+ "mean_mse": mean_mse_vec,
2039
+ }
2040
+
2041
+ _lasso_cv_cache_put(cache_key_eff, details)
2042
+
2043
+ if return_details:
2044
+ return details
2045
+
2046
+ return float(details["alpha"])
2047
+
2048
+
2049
+ from statgpu.linear_model.penalized._penalized_linear import PenalizedLinearRegression as _PenalizedLinearRegression
2050
+
2051
+
2052
+ class Lasso(_PenalizedLinearRegression):
2053
+ """Thin sklearn-style wrapper over ``PenalizedLinearRegression`` with L1 penalty."""
2054
+
2055
def __init__(
    self,
    alpha: float = 1.0,
    fit_intercept: bool = True,
    max_iter: int = 1000,
    tol: float = 1e-4,
    stopping: str = "coef_delta",
    inference_method: str = "debiased",
    n_bootstrap: int = 200,
    bootstrap_random_state: Optional[int] = None,
    enable_simultaneous_inference: bool = False,
    simultaneous_method: str = "maxz_bootstrap",
    simultaneous_alpha: float = 0.05,
    simultaneous_n_bootstrap: int = 1000,
    simultaneous_random_state: Optional[int] = None,
    simultaneous_include_intercept: bool = False,
    device: Union[str, Device] = Device.AUTO,
    n_jobs: Optional[int] = None,
    compute_inference: bool = True,
    solver: str = "fista",
    cpu_solver: str = "coordinate_descent",
    lipschitz_L: Optional[float] = None,
    admm_rho: float = 1.0,
    gpu_memory_cleanup: bool = False,
):
    """Configure an L1-penalised estimator and check simultaneous-inference options.

    The bootstrap settings, the ``simultaneous_*`` settings and ``admm_rho``
    are coerced and stored on the instance first. The remaining arguments
    are then forwarded to the parent constructor together with
    ``penalty="l1"``. ``inference_method`` is assigned only after the parent
    constructor has run, because the parent overwrites it with its own
    default.

    Raises
    ------
    ValueError
        When ``enable_simultaneous_inference`` is true and any of these
        holds: ``simultaneous_method`` is not ``'maxz_bootstrap'``, the
        stored inference method does not contain ``'debiased'``, or
        ``compute_inference`` is false.
    """
    # The parent constructor does not know these attributes, so they are
    # stored before it is invoked.
    self.n_bootstrap = int(n_bootstrap)
    self.bootstrap_random_state = bootstrap_random_state
    self.enable_simultaneous_inference = bool(enable_simultaneous_inference)

    # Keep the caller's own object when it already equals its lower-cased
    # form; otherwise store the lower-cased string.
    method_lower = str(simultaneous_method).lower()
    if simultaneous_method == method_lower:
        self.simultaneous_method = simultaneous_method
    else:
        self.simultaneous_method = method_lower

    self.simultaneous_alpha = float(simultaneous_alpha)
    self.simultaneous_n_bootstrap = int(simultaneous_n_bootstrap)
    self.simultaneous_random_state = simultaneous_random_state
    self.simultaneous_include_intercept = bool(simultaneous_include_intercept)
    self.admm_rho = float(admm_rho)

    # Everything else is handled by the penalised-regression base class,
    # with the penalty fixed to L1.
    parent_kwargs = dict(
        penalty="l1",
        alpha=alpha,
        fit_intercept=fit_intercept,
        max_iter=max_iter,
        tol=tol,
        device=device,
        n_jobs=n_jobs,
        cpu_solver=cpu_solver,
        solver=solver,
        lipschitz_L=lipschitz_L,
        gpu_memory_cleanup=gpu_memory_cleanup,
        compute_inference=compute_inference,
        stopping=stopping,
    )
    super().__init__(**parent_kwargs)

    # Assigned after the parent constructor, which resets this attribute
    # to its own default. Same normalisation rule as simultaneous_method.
    inference_lower = str(inference_method).lower()
    if inference_method == inference_lower:
        self.inference_method = inference_method
    else:
        self.inference_method = inference_lower

    # Simultaneous inference is only checked when it has been switched on;
    # the first failing condition determines the error message.
    if self.enable_simultaneous_inference:
        problem = None
        if self.simultaneous_method != "maxz_bootstrap":
            problem = (
                f"simultaneous_method must be 'maxz_bootstrap', "
                f"got '{self.simultaneous_method}'"
            )
        elif "debiased" not in self.inference_method:
            problem = "Simultaneous inference requires inference_method='debiased'."
        elif not self.compute_inference:
            problem = "Simultaneous inference requires compute_inference=True."
        if problem is not None:
            raise ValueError(problem)