statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2699 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unified cross-validated penalized GLM estimator.
|
|
3
|
+
|
|
4
|
+
Supports all GLM loss functions (squared_error, logistic, poisson, gamma,
|
|
5
|
+
inverse_gaussian, negative_binomial, tweedie) with all penalty types
|
|
6
|
+
(l1, l2, elasticnet, scad, mcp, adaptive_l1, group_lasso).
|
|
7
|
+
|
|
8
|
+
Optimizations:
|
|
9
|
+
- Warm-start across alpha values (descending order)
|
|
10
|
+
- Batch eigendecomposition for squared_error + l2 (CPU/CuPy/Torch)
|
|
11
|
+
- Precomputed loss function and cached validation data per fold
|
|
12
|
+
- Minimal D2H transfers
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations

__all__ = ["PenalizedGLM_CV"]

import logging
import warnings
from typing import Optional, Union

import numpy as np

from statgpu._config import Device
from statgpu.backends import _to_numpy
from statgpu.backends._array_ops import _copy_arr, _zeros, _xp_zeros, _soft_threshold
from statgpu.backends._utils import _to_float_scalar
from statgpu.cross_validation._base import CVEstimatorBase, kfold_indices
from statgpu.solvers._utils import _nesterov_momentum

# Module logger. Created after the import block so that all imports stay at the
# top of the module (an assignment between imports trips PEP 8 / ruff E402).
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ---------------------------------------------------------------------------
# GLM loss numerics (shared by every CV path)
# ---------------------------------------------------------------------------

# Symmetric clipping window for the linear predictor eta, applied before
# exp()/link evaluation so the mean cannot overflow.
_ETA_CLIP_STANDARD = 30.0  # Poisson, Gamma, negative binomial, inverse Gaussian
_ETA_CLIP_TWEEDIE = 50.0  # Tweedie: wider window for mu**p stability
_ETA_CLIP_LOGISTIC = 500.0  # Logistic: sigmoid is saturated long before this

# Bounds on the mean mu, guarding against division by zero and log(0).
_MU_LO = 1e-10  # generic floor
_MU_LO_TWEEDIE = 1e-3
_MU_HI_TWEEDIE = 1e4
_MU_LO_INVGAUSS = 5e-2
_MU_HI_INVGAUSS = 1e3
_MU_LO_NB = 1e-300

# Fallback loss parameters; these must agree with the defaults of the
# corresponding loss objects.
_NB_ALPHA_DEFAULT = 1.0  # NegativeBinomialLoss default alpha
_TWEEDIE_POWER_DEFAULT = 1.5  # TweedieLoss default power

# ---------------------------------------------------------------------------
# CV solver tuning
# ---------------------------------------------------------------------------

# Lower bound applied to eigenvalues so Lipschitz constants and ridge
# inverses never divide by zero.
_EIGVAL_FLOOR = 1e-15

# Per-alpha FISTA iteration caps used during CV (deliberately lower than the
# full-fit limits to keep CV fast).
_FISTA_MAX_ITER_CV = 400
_FISTA_MAX_ITER_CV_SMALL = 600  # larger cap, intended for small problems (n*p < _SMALL_PROBLEM_THRESHOLD)

# Convergence-check cadence, in iterations. Each check trades a possible
# device sync against reacting sooner to convergence.
_CONV_INTERVAL_CV_DEFAULT = 200
_CONV_INTERVAL_CV_TIGHT = 30  # first few alphas
_CONV_INTERVAL_CV_FOLD = 50  # per-fold solves
_CONV_INTERVAL_CV_PATH = 25  # path-based solves
_CONV_INTERVAL_CV_NUMPY = 10  # numpy path: no sync cost

# Problem-size heuristics.
_SMALL_PROBLEM_THRESHOLD = 200_000  # n*p below this counts as a "small problem"
_GPU_BREAK_EVEN_THRESHOLD = 100_000_000  # total CV work below this runs faster on CPU

# IRLS deviance convergence tolerances.
_IRLS_DEV_TOL_REL = 1e-10  # relative
_IRLS_DEV_TOL_ABS = 1e-6  # absolute floor
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class ApproximateCVWarning(UserWarning):
    """Category for warnings about approximate two-stage CV screening.

    Emitted when the approximate screening mode is turned on. Being a
    ``UserWarning`` subclass, it can be filtered on its own via
    ``warnings.filterwarnings(..., category=ApproximateCVWarning)``.
    """
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _is_uniform_weight(sample_weight) -> bool:
    """Return True when *sample_weight* is None, empty, or constant.

    "Constant" is judged with ``np.allclose`` against the first entry (numpy's
    default tolerances), after converting the weights to a flat float64
    numpy array.
    """
    if sample_weight is not None:
        flat = np.asarray(_to_numpy(sample_weight), dtype=np.float64).ravel()
        if flat.size == 0:
            return True
        return np.allclose(flat, flat[0])
    return True
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _device_to_name(device):
    """Normalize *device* to a backend name.

    ``Device`` enum members map to their ``.value``; any other object is
    converted with ``str()`` and lowercased.
    """
    return device.value if isinstance(device, Device) else str(device).lower()
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _slice_rows(arr, idx):
    """Select rows of *arr* using *idx*, converting the index to the array's backend.

    * CuPy arrays: ``idx`` is moved to the device with ``cupy.asarray``.
    * Torch tensors: ``idx`` becomes a tensor on ``arr.device``. Boolean
      indices are kept boolean, so they act as a row mask exactly as they do
      for numpy/CuPy. Any other index is cast to ``torch.long``.
    * Everything else: plain ``arr[idx]``; if that raises ``TypeError``
      (e.g. a Python list indexed by a list), ``arr`` is converted with
      ``np.asarray`` first.
    """
    mod = type(arr).__module__
    if mod.startswith("cupy"):
        import cupy as cp

        return arr[cp.asarray(idx)]
    if mod.startswith("torch"):
        import torch

        # Forcing dtype=torch.long on a boolean mask would turn True/False
        # into row indices 1/0 and silently select the wrong rows.
        idx_dtype = getattr(idx, "dtype", None)
        if idx_dtype is None:
            idx_dtype = np.asarray(idx).dtype
        index_dtype = torch.bool if str(idx_dtype) in ("bool", "torch.bool") else torch.long
        return arr[torch.as_tensor(idx, dtype=index_dtype, device=arr.device)]
    try:
        return arr[idx]
    except TypeError:
        return np.asarray(arr)[idx]
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _nanargmin_prefer_larger_alpha(scores, alpha_grid, rel_tol=1e-10, abs_tol=1e-12):
    """Select min score with deterministic tie-break toward stronger regularization.

    Scores within ``max(abs_tol, |best| * rel_tol)`` of the best *finite*
    score are treated as tied. Among the tied candidates, the index with the
    largest ``alpha_grid`` value is returned.

    Non-finite scores (NaN, +inf and -inf) are never selected. If no score is
    finite, a warning is emitted and index 0 is returned (the strongest
    regularization when the grid is ordered in descending alpha).

    Parameters
    ----------
    scores : array-like of float
        CV score per alpha (lower is better).
    alpha_grid : array-like of float
        Alpha value for each score, same length as ``scores``.
    rel_tol, abs_tol : float
        Relative/absolute tolerance defining a tie with the best score.

    Returns
    -------
    int
        Index into ``scores`` / ``alpha_grid``.
    """
    scores = np.asarray(scores, dtype=np.float64)
    alpha_grid = np.asarray(alpha_grid, dtype=np.float64)
    finite = np.isfinite(scores)
    if not np.any(finite):
        # All scores are NaN/Inf — fall back to first alpha (strongest regularization)
        warnings.warn("All CV scores are NaN/Inf; returning first alpha.", stacklevel=2)
        return 0
    # Minimum over finite scores only: a -inf entry would make ``best + tol``
    # NaN, leaving no candidates and making np.argmax raise on an empty array.
    best = float(np.min(scores[finite]))
    tol = max(float(abs_tol), abs(best) * float(rel_tol))
    candidates = np.flatnonzero(finite & (scores <= best + tol))
    return int(candidates[np.argmax(alpha_grid[candidates])])
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _two_stage_candidate_mask(scores, refine_top_k=3):
    """Return alpha candidates to strictly refine after approximate screening.

    The returned boolean mask (same length as ``scores``) marks:

    * both endpoints of the grid;
    * the ``refine_top_k`` best finite scores (at least 1) plus their
      immediate neighbours;
    * every finite score within ``max(0.5% of |best|, 1e-6)`` of the best
      finite score.

    Non-finite scores (NaN, +/-inf) are never ranked or treated as the best
    score. If no score is finite, a warning is emitted and every candidate is
    marked. An empty input yields an empty mask.
    """
    scores = np.asarray(scores, dtype=np.float64).ravel()
    n_scores = scores.size
    mask = np.zeros(n_scores, dtype=bool)
    if n_scores == 0:
        return mask
    finite = np.isfinite(scores)
    if not np.any(finite):
        warnings.warn("All approximate CV scores are NaN; refining all candidates.", stacklevel=2)
        mask[:] = True
        return mask

    # Endpoint alphas are common optima on flat or monotone CV curves. Always
    # refine them so approximate screening cannot drop boundary solutions.
    mask[0] = True
    mask[-1] = True

    k = min(max(1, int(refine_top_k)), int(np.count_nonzero(finite)))
    ranked = np.argsort(np.where(finite, scores, np.inf))[:k]
    for idx in ranked:
        lo = max(0, int(idx) - 1)
        hi = min(n_scores, int(idx) + 2)
        mask[lo:hi] = True

    # Best over finite scores only: with a -inf entry, nanmin would return
    # -inf, ``best + near_tol`` would be NaN, and this band would be empty.
    best = float(np.min(scores[finite]))
    near_tol = max(abs(best) * 0.005, 1e-6)
    mask |= finite & (scores <= best + near_tol)
    return mask
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# ---------------------------------------------------------------------------
# Per-sample loss for squared_error. Unlike the GLM losses, it is evaluated
# from the design matrix and the intercept-augmented coefficient vector.
# ---------------------------------------------------------------------------
def _ps_squared_error(eta, y, X_design=None, coef_with_intercept=None, **_):
    """Per-sample squared residuals ``(y - X_design @ coef_with_intercept) ** 2``.

    ``eta`` and any extra keyword arguments are accepted only so the call
    signature matches the other per-sample loss functions; they are ignored.
    """
    residual = y - X_design @ coef_with_intercept
    return residual ** 2


# Dispatch table: loss_name -> (per_sample_fn, uses_design).
#   uses_design=True  -> fn needs X_design and coef_with_intercept (squared_error)
#   uses_design=False -> fn works from eta directly (all GLM losses)
# Left empty here; filled in further down, once the loss formula registry
# functions have been defined.
_LOSS_EVAL_DISPATCH = {}
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _weighted_mean(per_sample, sw):
    """Average *per_sample*, weighted by *sw* when weights are supplied.

    Uses the plain mean when ``sw`` is None or when the weights sum to a
    non-positive value. Always returns a Python float.
    """
    if sw is None:
        return float(np.mean(per_sample))
    total_weight = float(np.sum(sw))
    if total_weight <= 0:
        return float(np.mean(per_sample))
    return float(np.dot(sw, per_sample) / total_weight)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _design_and_coef_with_intercept(X_val_np, coef_np, intercept, fit_intercept):
    """Return ``(X_design, coef_with_intercept)``, prepending a ones column and the intercept when fitting one."""
    if not fit_intercept:
        return X_val_np, coef_np
    n_val = X_val_np.shape[0]
    X_design = np.column_stack([np.ones(n_val), X_val_np])
    coef_with_intercept = np.concatenate([[float(intercept)], coef_np])
    return X_design, coef_with_intercept


def _evaluate_loss_numpy(loss_name, loss_fn, X_val_np, y_val_np, coef_np, intercept, fit_intercept, sample_weight=None):
    """Validation loss of fitted coefficients, computed in float64 numpy.

    Independent of the backend used for training. Losses registered in
    ``_LOSS_EVAL_DISPATCH`` are evaluated per sample and averaged with
    ``_weighted_mean``, so ``sample_weight`` (when given) yields a weighted
    mean. Negative-binomial ``alpha`` and Tweedie ``power`` are read from
    ``loss_fn``; if absent, the module defaults are used.

    Losses missing from the dispatch table fall back to
    ``loss_fn.value(X_design, y, coef_with_intercept)`` on an
    intercept-augmented design. That path cannot apply sample weights, so a
    ``RuntimeWarning`` is emitted when weights are supplied.

    Returns
    -------
    float
        Mean (or weighted mean) validation loss.
    """
    coef_np = np.asarray(coef_np, dtype=np.float64).ravel()
    sw = np.asarray(sample_weight, dtype=np.float64).ravel() if sample_weight is not None else None

    entry = _LOSS_EVAL_DISPATCH.get(loss_name)
    if entry is None:
        # Fallback: a weighted mean cannot be derived from the loss object's
        # unweighted value, so weights are ignored (with a warning).
        X_design, coef_with_intercept = _design_and_coef_with_intercept(
            X_val_np, coef_np, intercept, fit_intercept
        )
        if sw is not None:
            warnings.warn(
                f"_evaluate_loss_numpy: loss '{loss_name}' not in dispatch table, "
                "falling back to unweighted loss_fn.value(). Sample weights ignored.",
                RuntimeWarning,
                stacklevel=2,
            )
        return float(loss_fn.value(X_design, y_val_np, coef_with_intercept))

    # Resolve loss-specific parameters from the loss_fn object.
    loss_params = {}
    if loss_name == "negative_binomial":
        loss_params["alpha"] = float(getattr(loss_fn, "alpha", _NB_ALPHA_DEFAULT))
    elif loss_name == "tweedie":
        loss_params["power"] = float(getattr(loss_fn, "power", _TWEEDIE_POWER_DEFAULT))

    per_sample_fn, uses_design = entry
    eta = X_val_np @ coef_np + (float(intercept) if fit_intercept else 0.0)
    if uses_design:
        X_design, coef_with_intercept = _design_and_coef_with_intercept(
            X_val_np, coef_np, intercept, fit_intercept
        )
        per_sample = per_sample_fn(
            eta, y_val_np, X_design=X_design, coef_with_intercept=coef_with_intercept, **loss_params
        )
    else:
        per_sample = per_sample_fn(eta, y_val_np, **loss_params)
    return _weighted_mean(per_sample, sw)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _ridge_eig_batch(X_train_np, y_train_np, X_val_np, y_val_np, alphas_np):
    """Batch Ridge solve via eigendecomposition on numpy.

    For every alpha, solves the centered ridge system
    ``(Xc' Xc + n * alpha * I) coef = Xc' yc``. One eigendecomposition of
    ``Xc' Xc`` is shared by all alphas, with eigenvalues floored at
    ``_EIGVAL_FLOOR``. Each solution is then scored on the validation split.

    All inputs are converted to float64 numpy arrays first, so computation is
    float64 end to end, and ``alphas_np`` may be any sequence (a plain Python
    list would otherwise be repeated by ``n * alphas``).

    Parameters
    ----------
    X_train_np : array-like, shape (n, p)
    y_train_np : array-like, shape (n,)
    X_val_np : array-like, shape (n_val, p)
    y_val_np : array-like, shape (n_val,)
    alphas_np : array-like, shape (n_alphas,)

    Returns
    -------
    mse : ndarray, shape (n_alphas,)
        Validation mean squared error per alpha.
    coefs : ndarray, shape (p, n_alphas)
        One coefficient column per alpha.
    intercepts : ndarray, shape (n_alphas,)
    """
    X_train_np = np.asarray(X_train_np, dtype=np.float64)
    y_train_np = np.asarray(y_train_np, dtype=np.float64)
    X_val_np = np.asarray(X_val_np, dtype=np.float64)
    y_val_np = np.asarray(y_val_np, dtype=np.float64)
    alphas_np = np.asarray(alphas_np, dtype=np.float64).ravel()
    n = X_train_np.shape[0]

    X_mean = np.mean(X_train_np, axis=0)
    y_mean = np.mean(y_train_np)
    Xc = X_train_np - X_mean
    yc = y_train_np - y_mean

    XtX = Xc.T @ Xc
    eigvals, Q = np.linalg.eigh(XtX)
    eigvals = np.maximum(eigvals, _EIGVAL_FLOOR)

    QtXty = Q.T @ (Xc.T @ yc)
    n_alpha = n * alphas_np
    inv_diag = 1.0 / (eigvals[:, None] + n_alpha[None, :])
    coefs = Q @ (inv_diag * QtXty[:, None])
    intercepts = y_mean - X_mean @ coefs

    # Predict: X_val @ coef + intercept (intercept already includes -X_mean @ coef)
    y_pred = X_val_np @ coefs + intercepts[None, :]
    mse = np.mean((y_val_np[:, None] - y_pred) ** 2, axis=0)

    return mse, coefs, intercepts
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def _eig_ridge_solve(gram, rhs, shift):
    """Solve ``(gram + shift * I) x = rhs`` for symmetric *gram* via ``eigh``, flooring eigenvalues at ``_EIGVAL_FLOOR``."""
    eigvals, Q = np.linalg.eigh(gram)
    eigvals = np.maximum(eigvals, _EIGVAL_FLOOR)
    inv_diag = 1.0 / (eigvals + shift)
    return Q @ (inv_diag * (Q.T @ rhs))


def _ridge_eig_single(X_train_np, y_train_np, alpha, sample_weight=None):
    """Single Ridge solve via eigendecomposition. Returns ``(coef, intercept)``.

    Solves the centered ridge normal equations
    ``(Xc' W Xc + s * alpha * I) coef = Xc' W yc``, where ``W = diag(sample_weight)``
    and ``s = sum(sample_weight)``. Without weights, ``W = I`` and ``s = n``.
    Centering uses (weighted) column means, and the intercept is recovered as
    ``y_mean - X_mean @ coef``. Both paths cost one O(p^3) eigendecomposition.

    Inputs are converted to float64 numpy arrays first.

    Returns
    -------
    coef : ndarray, shape (p,)
    intercept : float
    """
    X_train_np = np.asarray(X_train_np, dtype=np.float64)
    y_train_np = np.asarray(y_train_np, dtype=np.float64)
    if sample_weight is not None:
        w = np.asarray(sample_weight, dtype=np.float64).ravel()
        X_mean = np.average(X_train_np, axis=0, weights=w)
        y_mean = float(np.average(y_train_np, weights=w))
        Xc = X_train_np - X_mean
        yc = y_train_np - y_mean
        # Scaling rows by sqrt(w) gives Xc' diag(w) Xc as an explicitly
        # symmetric product, which is what eigh expects.
        W_sqrt_Xc = Xc * np.sqrt(w)[:, None]
        gram = W_sqrt_Xc.T @ W_sqrt_Xc
        rhs = (Xc * w[:, None]).T @ yc
        shift = w.sum() * alpha
    else:
        X_mean = np.mean(X_train_np, axis=0)
        y_mean = np.mean(y_train_np)
        Xc = X_train_np - X_mean
        yc = y_train_np - y_mean
        gram = Xc.T @ Xc
        rhs = Xc.T @ yc
        shift = X_train_np.shape[0] * alpha

    coef = _eig_ridge_solve(gram, rhs, shift)
    intercept = float(y_mean - X_mean @ coef)
    return coef, intercept
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def _backend_name_for_cv_device(device):
    """Map a CV device spec to the array-backend name used for CV.

    ``"cuda"`` maps to ``"cupy"`` and ``"torch"`` to ``"torch"``. Every other
    device name maps to ``"numpy"``.
    """
    return {"cuda": "cupy", "torch": "torch"}.get(_device_to_name(device), "numpy")
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
# Import shared utility from statgpu.cross_validation._base
|
|
323
|
+
from statgpu.cross_validation._base import _torch_cuda_available
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def _logistic_sparse_effective_max_iter(max_iter, device, penalty_name, refit=False):
    """Iteration cap for sparse-penalty logistic CV fits.

    The cap applies only on GPU backends (``"cupy"``/``"torch"``) and only
    when ``refit`` is false:

    * ``"l1"`` uses ``min(max_iter, _FISTA_MAX_ITER_CV)``;
    * ``"elasticnet"``/``"en"`` uses ``min(max_iter, _FISTA_MAX_ITER_CV_SMALL)``.

    The penalty name is matched case-insensitively. In every other case,
    ``int(max_iter)`` is returned unchanged.
    """
    on_gpu = _backend_name_for_cv_device(device) in ("cupy", "torch")
    penalty = str(penalty_name).lower()
    if refit or not on_gpu:
        return int(max_iter)
    if penalty == "l1":
        return min(int(max_iter), _FISTA_MAX_ITER_CV)
    if penalty in ("elasticnet", "en"):
        return min(int(max_iter), _FISTA_MAX_ITER_CV_SMALL)
    return int(max_iter)
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def _glm_cv_effective_max_iter(max_iter, loss_name, penalty_name, device, refit=False):
    """CV-only iteration caps for GPU paths whose alpha ranking stabilizes early.

    When ``refit`` is false:

    * Tweedie with an l1/elasticnet penalty on cupy or torch is capped at 200;
    * negative binomial with l2 on cupy is capped at 30.

    Loss and penalty names are matched case-insensitively. Otherwise
    ``int(max_iter)`` is returned.
    """
    backend = _backend_name_for_cv_device(device)
    loss = str(loss_name).lower()
    penalty = str(penalty_name).lower()

    cap = None
    if not refit:
        if backend in ("cupy", "torch") and loss == "tweedie" and penalty in ("l1", "elasticnet", "en"):
            cap = 200
        elif backend == "cupy" and loss == "negative_binomial" and penalty == "l2":
            cap = 30
    return int(max_iter) if cap is None else min(int(max_iter), cap)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def _to_backend_float64(arr, backend):
    """Convert *arr* to a float64 array of the requested *backend*.

    * ``"cupy"``: ``cupy.asarray(arr, dtype=float64)``.
    * ``"torch"``: an existing tensor is cast on its current device; any
      other input goes through numpy and lands on the current CUDA device,
      or on the CPU when CUDA is unavailable.
    * anything else: a float64 numpy array.
    """
    if backend not in ("cupy", "torch"):
        return np.asarray(arr, dtype=np.float64)
    if backend == "cupy":
        import cupy as cp

        return cp.asarray(arr, dtype=cp.float64)

    import torch

    if not isinstance(arr, torch.Tensor):
        target = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
        return torch.as_tensor(np.asarray(arr, dtype=np.float64), dtype=torch.float64, device=target)
    # Tensor input: keep its device, only change the dtype.
    return arr.to(dtype=torch.float64)
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
# ---------------------------------------------------------------------------
|
|
367
|
+
# Unified fold-batched CV framework
|
|
368
|
+
# ---------------------------------------------------------------------------
|
|
369
|
+
|
|
370
|
+
def _fb_ones(shape, dtype, is_torch, device=None):
    """Allocate a ones array with torch when *is_torch*, otherwise with CuPy.

    ``device`` is only passed on the torch path.
    """
    if not is_torch:
        import cupy as cp

        return cp.ones(shape, dtype=dtype)
    import torch

    return torch.ones(shape, dtype=dtype, device=device)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def _fb_zeros(shape, dtype, is_torch, device=None):
    """Allocate a zeros array with torch when *is_torch*, otherwise with CuPy.

    ``device`` is only passed on the torch path.
    """
    if not is_torch:
        import cupy as cp

        return cp.zeros(shape, dtype=dtype)
    import torch

    return torch.zeros(shape, dtype=dtype, device=device)
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def _fb_as_tensor(arr, is_torch, device=None):
    """Turn index data into an int64 array on the fold-batch backend.

    The input is first converted to an int64 numpy array, then to a
    ``torch.long`` tensor (on ``device``) or a CuPy array.
    """
    indices = np.asarray(arr, dtype=np.int64)
    if not is_torch:
        import cupy as cp

        return cp.asarray(indices)
    import torch

    return torch.as_tensor(indices, dtype=torch.long, device=device)
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def _fb_copy(x, is_torch):
    """Return an independent copy of *x*: ``clone()`` for torch, ``copy()`` otherwise."""
    if is_torch:
        return x.clone()
    return x.copy()
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def _fb_cat(tensors, is_torch, dim=1):
    """Concatenate *tensors* along *dim* (``torch.cat`` or ``cupy.concatenate``)."""
    if not is_torch:
        import cupy as cp

        return cp.concatenate(tensors, axis=dim)
    import torch

    return torch.cat(tensors, dim=dim)
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def _fb_sum(x, is_torch, axis=0, keepdims=False):
    """Sum *x* along *axis*, translating to torch's ``dim``/``keepdim`` keywords when needed."""
    if not is_torch:
        return x.sum(axis=axis, keepdims=keepdims)
    return x.sum(dim=axis, keepdim=keepdims)
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def _fb_stack(arrays, is_torch, dim=1):
    """Stack *arrays* along a new axis *dim* (``torch.stack`` or ``cupy.stack``)."""
    if not is_torch:
        import cupy as cp

        return cp.stack(arrays, axis=dim)
    import torch

    return torch.stack(arrays, dim=dim)
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def _fold_batch_lipschitz_logistic(X_aug, y_train, n_train, is_torch):
    """Lipschitz estimate for the mean logistic loss gradient.

    Returns ``eig_max / (4 * max(n_train, 1))``, floored at 1e-12. Here
    ``eig_max`` is the largest-eigenvalue estimate of ``X_aug' X_aug`` from
    ``_max_eigval_power``, and the factor 1/4 bounds the sigmoid derivative.
    ``y_train`` and ``is_torch`` are unused and only keep the signature shared
    with the other fold-batch Lipschitz helpers.
    """
    top_eig = _max_eigval_power(X_aug.T @ X_aug)
    denom = 4.0 * max(n_train, 1)
    return max(top_eig / denom, 1e-12)
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def _fold_batch_lipschitz_exp_link(X_aug, y_train, n_train, is_torch):
    """Lipschitz estimate for log-link GLMs (Poisson, Gamma, NB, InvGauss, Tweedie).

    Returns ``max(eig_max / max(n_train, 1), 1e-12) * y_scale``, where:

    * ``eig_max`` is the largest-eigenvalue estimate of ``X_aug' X_aug`` from
      ``_max_eigval_power``;
    * ``y_scale = max(1, y_mean, sqrt(y_mean * y_max))``.

    A Gamma-specific variant also exists: ``_fold_batch_lipschitz_gamma``.

    ``is_torch`` is kept for signature compatibility only: ``float()`` works
    on the 0-d reductions of both torch tensors and CuPy arrays, so no
    backend branch (and no backend import) is needed.
    """
    eig_max = _max_eigval_power(X_aug.T @ X_aug)
    y_mean = float(y_train.mean())
    y_max = float(y_train.max())
    # Log-link targets should be non-negative. Clamping y_mean at 0 avoids
    # sqrt(negative) -> NaN (plus a RuntimeWarning); the resulting scale is
    # the same 1.0 that max() previously ended up with in that case.
    y_scale = max(1.0, y_mean, np.sqrt(max(y_mean, 0.0) * max(y_max, 1e-10)))
    return max(eig_max / max(n_train, 1), 1e-12) * y_scale
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def _fold_batch_lipschitz_gamma(X_aug, y_train, n_train, is_torch):
    """Lipschitz estimate for Gamma log-link: ``eig_max(X'X)/n * max(y/y_mean)``.

    Differs from ``_fold_batch_lipschitz_exp_link`` because Gamma's Hessian
    weights are y/mu (not mu), so the scaling uses the y-ratio instead of a
    y-moment. The ratio is ``max(y) / y_mean``, which equals ``max(y / y_mean)``
    for ``y_mean > 0`` and avoids an n-length temporary. When ``y_mean <= 0``,
    the ratio defaults to 1. The result is
    ``max(eig_max / max(n_train, 1), 1e-12) * max(1, ratio)``.

    ``is_torch`` is kept for signature compatibility only; ``float()`` handles
    the 0-d reductions of both torch tensors and CuPy arrays.
    """
    eig_max = _max_eigval_power(X_aug.T @ X_aug)
    y_mean = float(y_train.mean())
    y_ratio_max = float(y_train.max()) / y_mean if y_mean > 0 else 1.0
    return max(eig_max / max(n_train, 1), 1e-12) * max(1.0, y_ratio_max)
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
# Loss-specific configs carry only lipschitz_fn and intercept_fn; the residual and validation-loss formulas come from the registry below.
|
|
468
|
+
# ---------------------------------------------------------------------------
|
|
469
|
+
# Loss formula registry — single source of truth for residual and val_loss
|
|
470
|
+
# ---------------------------------------------------------------------------
|
|
471
|
+
# Each loss registers (residual_fn, val_loss_fn) that work with any backend
|
|
472
|
+
# (numpy/torch/cupy) via elementwise ops. The FISTA hot loop calls these
|
|
473
|
+
# instead of inline if/elif chains, eliminating formula duplication.
|
|
474
|
+
#
|
|
475
|
+
# Signature: fn(eta, y, **params) -> per_sample_loss_or_residual
|
|
476
|
+
# `eta` and `y` are backend arrays; `params` carries loss-specific scalars.
|
|
477
|
+
|
|
478
|
+
# Use backend-agnostic utilities from statgpu.backends._array_ops
|
|
479
|
+
# Imported here, ahead of the loss functions that call _sigmoid and _softplus.
# Names are resolved at call time, so the Lipschitz helpers defined above can also use _max_eigval_power.
|
|
481
|
+
from statgpu.backends._array_ops import (
|
|
482
|
+
_clip as _safe_clip,
|
|
483
|
+
_xp as _get_xp,
|
|
484
|
+
_sigmoid,
|
|
485
|
+
_softplus,
|
|
486
|
+
_abs_sum_dev,
|
|
487
|
+
_device_gt,
|
|
488
|
+
_max_eigval_power,
|
|
489
|
+
)
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
# Per-loss formula tables, filled through _register_loss_fns():
#   _LOSS_RESIDUAL_FNS[name] -> per-sample residual (gradient w.r.t. eta)
#   _LOSS_VALLOSS_FNS[name]  -> per-sample validation loss
_LOSS_RESIDUAL_FNS = {}
_LOSS_VALLOSS_FNS = {}


def _register_loss_fns(loss_name, residual_fn, val_loss_fn):
    """Record the residual and validation-loss functions for *loss_name*.

    Registering the same name again overwrites the earlier entries.
    """
    for table, fn in ((_LOSS_RESIDUAL_FNS, residual_fn), (_LOSS_VALLOSS_FNS, val_loss_fn)):
        table[loss_name] = fn
|
|
499
|
+
|
|
500
|
+
# --- Logistic ---
|
|
501
|
+
def _res_logistic(eta, y, **_):
    """Logistic-loss gradient with respect to eta: ``sigmoid(eta) - y``."""
    prob = _sigmoid(eta)
    return prob - y
|
|
504
|
+
|
|
505
|
+
def _val_logistic(eta, y, **_):
    """Per-sample logistic loss: ``-y * eta + softplus(eta)``."""
    log_partition = _softplus(eta)
    return -y * eta + log_partition
|
|
508
|
+
|
|
509
|
+
# --- Poisson ---
|
|
510
|
+
def _res_poisson(eta, y, **_):
    """Poisson-loss gradient with respect to eta: ``mu - y``.

    ``mu = exp(clip(eta, +/-_ETA_CLIP_STANDARD))``, floored at ``_MU_LO``; this
    is the derivative of ``mu - y * log(mu)``.
    """
    backend = _get_xp(eta)
    eta_clipped = _safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD)
    mean = _safe_clip(backend.exp(eta_clipped), _MU_LO, None)
    return mean - y
|
|
516
|
+
|
|
517
|
+
def _val_poisson(eta, y, **_):
    """Per-sample Poisson loss: ``mu - y * log(mu)``.

    ``mu = exp(clip(eta, +/-_ETA_CLIP_STANDARD))``, floored at ``_MU_LO``.
    """
    backend = _get_xp(eta)
    eta_clipped = _safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD)
    mean = _safe_clip(backend.exp(eta_clipped), _MU_LO, None)
    return mean - y * backend.log(mean)
|
|
522
|
+
|
|
523
|
+
# --- Gamma ---
|
|
524
|
+
def _res_gamma(eta, y, **_):
    """Per-sample Gamma (log-link) gradient w.r.t. ``eta``: ``1 - y / mu``."""
    ns = _get_xp(eta)
    eta_safe = _safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD)
    mu_floor = _safe_clip(ns.exp(eta_safe), _MU_LO, None)
    return 1.0 - y / mu_floor
|
|
529
|
+
|
|
530
|
+
def _val_gamma(eta, y, **_):
    """Per-sample Gamma (log-link) loss ``y / mu + log(mu)``."""
    ns = _get_xp(eta)
    eta_safe = _safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD)
    mu_floor = _safe_clip(ns.exp(eta_safe), _MU_LO, None)
    return y / mu_floor + ns.log(mu_floor)
|
|
535
|
+
|
|
536
|
+
# --- Inverse Gaussian ---
|
|
537
|
+
def _res_invgauss(eta, y, **_):
    """Per-sample inverse-Gaussian (log-link) gradient w.r.t. ``eta``: ``(mu - y) / mu**2``."""
    ns = _get_xp(eta)
    mu = ns.exp(_safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD))
    # Floor mu**2 itself (not mu) so the denominator never drops below _MU_LO,
    # rather than below _MU_LO**2.
    denom = _safe_clip(mu * mu, _MU_LO, None)
    return (mu - y) / denom
|
|
543
|
+
|
|
544
|
+
def _val_invgauss(eta, y, **_):
    """Per-sample inverse-Gaussian loss ``y / (2 * mu**2) - 1 / mu``."""
    ns = _get_xp(eta)
    mu = ns.exp(_safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD))
    # mu**2 and mu are floored independently at _MU_LO (mu**2 is floored
    # directly, matching _ps_inverse_gaussian per the original note).
    mu_sq_floor = _safe_clip(mu * mu, _MU_LO, None)
    mu_floor = _safe_clip(mu, _MU_LO, None)
    return y / (2.0 * mu_sq_floor) - 1.0 / mu_floor
|
|
552
|
+
|
|
553
|
+
# --- Negative Binomial ---
|
|
554
|
+
def _res_nb(eta, y, alpha=_NB_ALPHA_DEFAULT, **_):
    """Per-sample negative-binomial gradient w.r.t. ``eta``: ``(mu - y) / (1 + alpha * mu)``."""
    ns = _get_xp(eta)
    eta_safe = _safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD)
    mu_floor = _safe_clip(ns.exp(eta_safe), _MU_LO, None)
    numerator = mu_floor - y
    return numerator / (1.0 + alpha * mu_floor)
|
|
560
|
+
|
|
561
|
+
def _val_nb(eta, y, alpha=_NB_ALPHA_DEFAULT, **_):
    """Per-sample negative-binomial loss.

    ``-y * log(mu / (1 + alpha*mu)) + (1/alpha) * log(1 + alpha*mu)``.
    """
    ns = _get_xp(eta)
    eta_safe = _safe_clip(eta, -_ETA_CLIP_STANDARD, _ETA_CLIP_STANDARD)
    mu_floor = _safe_clip(ns.exp(eta_safe), _MU_LO, None)
    shrink = 1.0 + alpha * mu_floor
    return -y * ns.log(mu_floor / shrink) + (1.0 / alpha) * ns.log(shrink)
|
|
567
|
+
|
|
568
|
+
# --- Tweedie ---
|
|
569
|
+
def _res_tweedie(eta, y, power=_TWEEDIE_POWER_DEFAULT, **_):
    """Per-sample Tweedie (log-link) gradient w.r.t. ``eta``: ``mu**(1 - power) * (mu - y)``.

    Uses the Tweedie-specific clip bounds for ``eta`` and ``mu``.
    """
    ns = _get_xp(eta)
    eta_safe = _safe_clip(eta, -_ETA_CLIP_TWEEDIE, _ETA_CLIP_TWEEDIE)
    mu_c = _safe_clip(ns.exp(eta_safe), _MU_LO_TWEEDIE, _MU_HI_TWEEDIE)
    scale = ns.exp((1 - power) * ns.log(mu_c))  # mu**(1 - power) via exp/log
    return scale * (mu_c - y)
|
|
574
|
+
|
|
575
|
+
def _val_tweedie(eta, y, power=_TWEEDIE_POWER_DEFAULT, **_):
    """Per-sample Tweedie loss ``-y*mu**(1-p)/(1-p) + mu**(2-p)/(2-p)``.

    At ``p == 1`` (Poisson) or ``p == 2`` (Gamma), the corresponding term's
    exponent vanishes (within 1e-10). That term is replaced by its log form:
    ``-y*log(mu)`` or ``log(mu)`` respectively.
    """
    ns = _get_xp(eta)
    eta_safe = _safe_clip(eta, -_ETA_CLIP_TWEEDIE, _ETA_CLIP_TWEEDIE)
    mu_c = _safe_clip(ns.exp(eta_safe), _MU_LO_TWEEDIE, _MU_HI_TWEEDIE)
    log_mu = ns.log(mu_c)
    d1 = 1.0 - power
    d2 = 2.0 - power
    if abs(d1) > 1e-10:
        term1 = -y * ns.exp(d1 * log_mu) / d1
    else:
        term1 = -y * log_mu
    if abs(d2) > 1e-10:
        term2 = ns.exp(d2 * log_mu) / d2
    else:
        term2 = log_mu
    return term1 + term2
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
# Wire every supported loss name to its (residual, validation-loss) pair.
for _loss_key, _res_impl, _val_impl in (
    ("logistic", _res_logistic, _val_logistic),
    ("poisson", _res_poisson, _val_poisson),
    ("gamma", _res_gamma, _val_gamma),
    ("inverse_gaussian", _res_invgauss, _val_invgauss),
    ("negative_binomial", _res_nb, _val_nb),
    ("tweedie", _res_tweedie, _val_tweedie),
):
    _register_loss_fns(_loss_key, _res_impl, _val_impl)
del _loss_key, _res_impl, _val_impl
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
# Point _LOSS_EVAL_DISPATCH at the backend-agnostic _val_* functions (they
# pick numpy/cupy/torch via _get_xp).  Per the original note,
# _evaluate_loss_numpy always passes numpy arrays, so results match the older
# _ps_* implementations.  The second tuple element is a per-loss flag consumed
# by the dispatcher elsewhere (True only for squared_error here) -- its meaning
# is not visible in this chunk.
_LOSS_EVAL_DISPATCH.update(
    logistic=(_val_logistic, False),
    squared_error=(_ps_squared_error, True),
    poisson=(_val_poisson, False),
    gamma=(_val_gamma, False),
    inverse_gaussian=(_val_invgauss, False),
    negative_binomial=(_val_nb, False),
    tweedie=(_val_tweedie, False),
)
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
# Fold-batch configuration, keyed by loss name.  Each entry (added by
# _register_fold_batch) holds "lipschitz_fn" and "intercept_fn".  Per-sample
# residuals and validation losses live in the separate loss-formula
# registries (_LOSS_RESIDUAL_FNS / _LOSS_VALLOSS_FNS).
_FOLD_BATCH_CONFIGS = dict()
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
def _logistic_intercept(y_mean, is_torch):
    """Initial logistic intercept: ``logit(clip(y_mean, 1e-3, 0.999))``.

    ``is_torch`` selects torch ops; otherwise cupy ops are used.  The clip
    keeps the log-odds finite when a fold's mean response is 0 or 1.
    """
    if not is_torch:
        import cupy as cp
        p = cp.clip(y_mean, 1e-3, 0.999)
        return cp.log(p) - cp.log(1.0 - p)
    import torch
    p = torch.clamp(y_mean, min=1e-3, max=0.999)
    return torch.log(p) - torch.log(1.0 - p)
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
def _exp_link_intercept(y_mean, is_torch):
    """Initial intercept for log-link GLMs: ``log(clip(y_mean, 1e-3, 100))``.

    ``is_torch`` selects torch ops; otherwise cupy ops are used.
    """
    if not is_torch:
        import cupy as cp
        return cp.log(cp.clip(y_mean, 1e-3, 100.0))
    import torch
    bounded = torch.clamp(y_mean, min=1e-3, max=100.0)
    return torch.log(bounded)
|
|
635
|
+
|
|
636
|
+
|
|
637
|
+
def _register_fold_batch(loss_name, lipschitz_fn, intercept_fn):
    """Store the fold-batch Lipschitz and intercept-initialisation callables for ``loss_name``.

    An existing entry for the same name is replaced.
    """
    entry = dict(lipschitz_fn=lipschitz_fn, intercept_fn=intercept_fn)
    _FOLD_BATCH_CONFIGS[loss_name] = entry
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
# Per-loss fold-batch wiring: (loss name, Lipschitz bound fn, intercept init fn).
for _fb_loss, _fb_lipschitz, _fb_intercept in (
    ("logistic", _fold_batch_lipschitz_logistic, _logistic_intercept),
    ("poisson", _fold_batch_lipschitz_exp_link, _exp_link_intercept),
    ("gamma", _fold_batch_lipschitz_gamma, _exp_link_intercept),
    ("inverse_gaussian", _fold_batch_lipschitz_exp_link, _exp_link_intercept),
    ("negative_binomial", _fold_batch_lipschitz_exp_link, _exp_link_intercept),
    ("tweedie", _fold_batch_lipschitz_exp_link, _exp_link_intercept),
):
    _register_fold_batch(_fb_loss, _fb_lipschitz, _fb_intercept)
del _fb_loss, _fb_lipschitz, _fb_intercept
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
def _glm_sparse_cv_folds(
    X,
    y,
    folds,
    alpha_sorted,
    penalty_name,
    l1_ratio,
    max_iter,
    tol,
    loss_name,
    device_backend,
    sample_weight=None,
    loss_kwargs=None,
):
    """Unified fold-batched sparse GLM CV path for all losses and backends.

    All CV folds are solved simultaneously.  Each column of the
    ``(n_features, n_folds)`` coefficient matrix belongs to one fold.
    Per-fold train/validation membership is encoded in
    ``(n_samples, n_folds)`` 0/1 masks over the full data.  The alphas in
    ``alpha_sorted`` are visited in order, each warm-started from the
    previous solution.  FISTA runs until every fold's coefficient change
    drops below ``tol``, or until ``max_iter`` is reached.  The per-fold
    validation loss is then recorded.

    Uses direct torch/cupy API calls (no abstraction layer) for performance
    in the FISTA hot loop.

    Parameters
    ----------
    X, y : array-like
        Full design matrix and response; converted to float64 on the device.
    folds : sequence of (train_idx, val_idx)
        Index pairs, one per fold.  Only rows in ``train_idx`` contribute to
        a fold's fit and only rows in ``val_idx`` to its score.  Folds need
        not partition the samples (e.g. time-series or shuffle splits).
    alpha_sorted : array-like
        Regularization strengths, visited in the given order.
    penalty_name : str
        ``"elasticnet"``/``"en"`` selects the elastic-net proximal step; any
        other name uses a pure L1 soft-threshold.
    l1_ratio : float
        Elastic-net mixing parameter (elastic net only).
    max_iter, tol : int, float
        FISTA iteration cap and per-fold L1 convergence threshold.
    loss_name : str
        Key into the fold-batch and loss-formula registries.
    device_backend : str
        ``"torch"`` selects the torch path; anything else the cupy path.
    sample_weight : array-like, optional
        Per-sample weights applied to training gradients and validation scores.
    loss_kwargs : dict, optional
        Loss-specific parameters (``alpha`` for negative binomial, ``power``
        for Tweedie) overriding the resolved loss object's defaults.

    Returns
    -------
    dict or None
        ``{"scores": float64, "n_iter": int64}`` arrays built from the
        per-alpha, per-fold validation losses and iteration counts (one entry
        per alpha, stacked via ``_fb_stack``).

        ``None`` is returned instead when any of these hold:

        - ``loss_name`` has no fold-batch config;
        - the requested GPU backend is unavailable;
        - fewer than two folds are given;
        - ``alpha_sorted`` is empty.
    """
    cfg = _FOLD_BATCH_CONFIGS.get(loss_name)
    if cfg is None:
        return None

    # Cheap structural checks first, before any device transfer.
    alphas = np.asarray(alpha_sorted, dtype=np.float64).ravel()
    n_folds = len(folds)
    if n_folds < 2 or alphas.size == 0:
        return None

    # Resolve loss-specific parameters: user-specified kwargs override defaults
    _lk = loss_kwargs or {}
    from statgpu.linear_model.penalized._fit_mixin import _resolve_loss_name
    _loss_obj = _resolve_loss_name(loss_name, loss_kwargs=_lk)
    _nb_alpha = float(_lk.get('alpha', getattr(_loss_obj, 'alpha', _NB_ALPHA_DEFAULT)))
    _tw_power = float(_lk.get('power', getattr(_loss_obj, 'power', _TWEEDIE_POWER_DEFAULT)))

    is_torch = (device_backend == "torch")
    if is_torch:
        if _backend_name_for_cv_device("torch") != "torch":
            return None
        try:
            import torch
            if not torch.cuda.is_available():
                return None
        except (ImportError, RuntimeError, OSError):
            return None
    else:
        if _backend_name_for_cv_device("cuda") != "cupy":
            return None
        try:
            import cupy as cp
            if cp.cuda.runtime.getDeviceCount() <= 0:
                return None
        except (ImportError, RuntimeError, OSError):
            return None

    Xb = _to_backend_float64(X, device_backend)
    yb = _to_backend_float64(y, device_backend).reshape(-1)
    penalty_name = str(penalty_name).lower()
    is_enet = penalty_name in ("elasticnet", "en")
    n_samples, n_features = Xb.shape

    lipschitz_fn = cfg["lipschitz_fn"]
    intercept_fn = cfg["intercept_fn"]

    # --- Build masks and compute per-fold Lipschitz ---
    dev = Xb.device if is_torch else None
    # Training membership is set explicitly from train_idx rather than taken
    # as the complement of val_idx.  Rows belonging to neither side of a fold
    # then never leak into that fold's gradient or intercept init, consistent
    # with the Lipschitz bound below (computed from X[train_idx] only).
    train_mask = _fb_zeros((n_samples, n_folds), Xb.dtype, is_torch, dev)
    val_mask = _fb_zeros((n_samples, n_folds), Xb.dtype, is_torch, dev)

    # Sample weight mask: per-fold weights (n_samples, n_folds)
    # Deferred until after the fold loop when train_mask is finalized.
    has_weights = sample_weight is not None
    if has_weights:
        sw_all = _to_backend_float64(sample_weight, device_backend).reshape(-1)

    step_values = []
    for fold_idx, (train_idx, val_idx) in enumerate(folds):
        train_idx_dev = _fb_as_tensor(train_idx, is_torch, dev)
        val_idx_dev = _fb_as_tensor(val_idx, is_torch, dev)
        train_mask[train_idx_dev, fold_idx] = 1.0
        val_mask[val_idx_dev, fold_idx] = 1.0

        X_train = Xb[train_idx_dev]
        y_train = yb[train_idx_dev]
        ones = _fb_ones((X_train.shape[0], 1), Xb.dtype, is_torch, dev)
        X_aug = _fb_cat([X_train, ones], is_torch)
        n_train = int(X_train.shape[0])
        # For weighted Lipschitz, pass sum(w) as effective n for normalization
        if has_weights:
            sw_fold = sw_all[train_idx_dev]
            sw_col_fold = sw_fold.reshape(-1, 1)
            _xp_sw = _get_xp(sw_col_fold)
            Xw = X_aug * _xp_sw.sqrt(sw_col_fold)
            w_sum_fold = float(sw_fold.sum().item()) if is_torch else float(sw_fold.sum())
            L_loss = lipschitz_fn(Xw, y_train, max(w_sum_fold, 1.0), is_torch)
        else:
            L_loss = lipschitz_fn(X_aug, y_train, n_train, is_torch)
        step_values.append(1.0 / L_loss)

    # Build sw_mask now that train_mask is finalized (non-train rows are 0)
    if has_weights:
        sw_mask = sw_all.reshape(-1, 1) * train_mask
    else:
        sw_mask = train_mask  # 1 on each fold's training rows, 0 elsewhere

    # --- Initialize parameters ---
    sw_train_vec = _fb_sum(sw_mask, is_torch, axis=0, keepdims=True).reshape(1, n_folds)
    # Guard against zero-weight folds (would cause division-by-zero)
    if is_torch:
        sw_train_vec = torch.clamp(sw_train_vec, min=1e-10)
    else:
        sw_train_vec = cp.clip(sw_train_vec, 1e-10, None)
    n_val_vec = _fb_sum(val_mask, is_torch, axis=0, keepdims=True).reshape(1, n_folds)
    # Guard against zero-sample validation folds (division-by-zero)
    if is_torch:
        n_val_vec = torch.clamp(n_val_vec, min=1.0)
    else:
        n_val_vec = cp.maximum(n_val_vec, 1.0)
    y_col = yb.reshape(-1, 1)
    # Weighted mean of y per fold
    y_mean = _fb_sum(y_col * sw_mask, is_torch, axis=0, keepdims=True) / sw_train_vec
    intercept = intercept_fn(y_mean, is_torch).reshape(1, n_folds)
    coef = _fb_zeros((n_features, n_folds), Xb.dtype, is_torch, dev)
    if is_torch:
        step = torch.as_tensor(step_values, dtype=Xb.dtype, device=dev).reshape(1, n_folds)
    else:
        step = cp.asarray(step_values, dtype=Xb.dtype).reshape(1, n_folds)

    tol_float = float(tol)
    scores_path = []
    iters_path = []

    # Pre-build loss kwargs to avoid dict construction in hot loop
    _loss_kwargs = {}
    if loss_name == "negative_binomial":
        _loss_kwargs["alpha"] = _nb_alpha
    elif loss_name == "tweedie":
        _loss_kwargs["power"] = _tw_power

    # Hoist function lookups outside hot loop (avoid dict lookup per iteration)
    _resid_fn = _LOSS_RESIDUAL_FNS[loss_name]
    _valloss_fn = _LOSS_VALLOSS_FNS[loss_name]

    # Precompute sw_val_mask/sw_val_vec once (val_mask is constant across alphas)
    if has_weights:
        sw_val_mask = sw_all.reshape(-1, 1) * val_mask
        sw_val_vec = _fb_sum(sw_val_mask, is_torch, axis=0, keepdims=True).reshape(1, n_folds)
        sw_val_vec = torch.clamp(sw_val_vec, min=1e-10) if is_torch else cp.clip(sw_val_vec, 1e-10, None)

    # --- FISTA loop over alphas ---
    # y_coef / y_intercept are the extrapolated iterates (standard FISTA notation).
    for alpha in alphas:
        # Proximal parameters depend only on alpha and the per-fold step, so
        # they are computed once per alpha rather than once per iteration.
        if is_enet:
            thresh = float(alpha) * float(l1_ratio) * step
            denom = 1.0 + float(alpha) * (1.0 - float(l1_ratio)) * step
        else:
            thresh = float(alpha) * step
            denom = 1.0

        y_coef = _fb_copy(coef, is_torch)
        y_intercept = _fb_copy(intercept, is_torch)
        t_k = 1.0
        # `active` marks folds still iterating; converged folds are frozen.
        if is_torch:
            active = torch.ones((1, n_folds), dtype=torch.bool, device=Xb.device)
            last_iter = torch.zeros((n_folds,), dtype=torch.int64, device=Xb.device)
        else:
            active = cp.ones((1, n_folds), dtype=bool)
            last_iter = cp.zeros((n_folds,), dtype=cp.int64)

        for iteration in range(int(max_iter)):
            coef_old = _fb_copy(coef, is_torch)
            intercept_old = _fb_copy(intercept, is_torch)

            eta = Xb @ y_coef + y_intercept
            # Per-sample residual from the loss registry, zeroed outside each
            # fold's training rows.
            resid = _resid_fn(eta, y_col, **_loss_kwargs) * train_mask
            # Weighted gradient: multiply residual by sw_mask (includes train_mask)
            # and divide by sum of weights per fold
            weighted_resid = resid * sw_mask
            grad_coef = (Xb.T @ weighted_resid) / sw_train_vec
            grad_intercept = _fb_sum(weighted_resid, is_torch, axis=0, keepdims=True) / sw_train_vec

            w = y_coef - step * grad_coef
            coef_new = _soft_threshold(w, thresh) / denom
            intercept_new = y_intercept - step * grad_intercept

            coef = torch.where(active, coef_new, coef) if is_torch else cp.where(active, coef_new, coef)
            intercept = torch.where(active, intercept_new, intercept) if is_torch else cp.where(active, intercept_new, intercept)

            beta, t_k = _nesterov_momentum(t_k, beta_cap=0.5)
            y_coef_new = coef + beta * (coef - coef_old)
            y_intercept_new = intercept + beta * (intercept - intercept_old)
            y_coef = torch.where(active, y_coef_new, coef) if is_torch else cp.where(active, y_coef_new, coef)
            y_intercept = torch.where(active, y_intercept_new, intercept) if is_torch else cp.where(active, y_intercept_new, intercept)
            if is_torch:
                last_iter = torch.where(active.reshape(-1), torch.full_like(last_iter, iteration + 1), last_iter)
            else:
                last_iter = cp.where(active.reshape(-1), cp.full_like(last_iter, iteration + 1), last_iter)

            # Check convergence: every iteration for first 20, then every 50
            if iteration < 20 or iteration % 50 == 0:
                if is_torch:
                    delta = torch.sum(torch.abs(coef - coef_old), dim=0, keepdim=True) + torch.abs(intercept - intercept_old)
                else:
                    delta = cp.sum(cp.abs(coef - coef_old), axis=0, keepdims=True) + cp.abs(intercept - intercept_old)
                active = active & (delta >= tol_float)
                _any_active = torch.any(active) if is_torch else cp.any(active)
                if not _to_float_scalar(_any_active):
                    break

        # Validation loss via loss registry (single call, backend-agnostic)
        eta_val = Xb @ coef + intercept
        val_loss = _valloss_fn(eta_val, y_col, **_loss_kwargs) * val_mask
        if has_weights:
            scores_path.append(_fb_sum(val_loss * sw_val_mask, is_torch, axis=0, keepdims=True).reshape(-1) / sw_val_vec.reshape(-1))
        else:
            scores_path.append(_fb_sum(val_loss, is_torch, axis=0, keepdims=True).reshape(-1) / n_val_vec.reshape(-1))
        iters_path.append(last_iter)

    scores = _fb_stack(scores_path, is_torch)
    n_iter = _fb_stack(iters_path, is_torch)
    return {
        "scores": np.asarray(_to_numpy(scores), dtype=np.float64),
        "n_iter": np.asarray(_to_numpy(n_iter), dtype=np.int64),
    }
|
|
877
|
+
|
|
878
|
+
|
|
879
|
+
def _scalar_to_float(x):
    """Convert a backend scalar to a Python ``float`` by way of ``_to_numpy``."""
    host_value = _to_numpy(x)
    return float(host_value)
|
|
881
|
+
|
|
882
|
+
|
|
883
|
+
def _logistic_sparse_cv_path(
    X_train,
    y_train,
    alpha_sorted,
    penalty_name,
    l1_ratio,
    max_iter,
    tol,
    device,
    X_val=None,
    y_val=None,
    sample_weight=None,
    val_sample_weight=None,
    return_path=True,
):
    """Fit a logistic sparse alpha path and optionally score validation loss.

    This CV-only path uses a fixed global Lipschitz bound and sparse proximal
    updates, avoiding per-iteration Armijo/objective synchronizations.

    The step size is ``1 / L`` with ``L = lambda_max(X_aug^T X_aug) / (4 n)``,
    where ``X_aug`` is ``X_train`` with a column of ones.  Alphas are visited
    in the given order, each warm-started from the previous coefficients and
    intercept.

    Parameters
    ----------
    X_train, y_train : array-like
        Training data for one fold; ``y_train`` is expected to hold 0/1 labels.
    alpha_sorted : array-like
        Regularization strengths, visited in the given order.
    penalty_name : str
        ``"elasticnet"``/``"en"`` selects the elastic-net proximal step; any
        other name uses a pure L1 soft-threshold.
    l1_ratio : float
        Elastic-net mixing parameter (elastic net only).
    max_iter, tol : int, float
        FISTA iteration cap and convergence threshold on the L1 change of
        coefficients plus the absolute change of the intercept.
    device : str
        CV device, mapped to a backend by ``_backend_name_for_cv_device``.
    X_val, y_val : array-like, optional
        Validation data; when both are given, per-alpha mean log-loss is scored.
    sample_weight : array-like, optional
        Only uniform weights are supported; otherwise a ``RuntimeWarning`` is
        emitted and ``None`` is returned so the caller can fall back.
    val_sample_weight : array-like, optional
        Per-sample weights for validation scoring. When provided, validation
        loss is computed as weighted mean; if they do not sum to a positive
        value, the unweighted mean is used.
    return_path : bool
        Also return per-alpha coefficients and intercepts.

    Returns
    -------
    dict or None
        ``"scores"``: per-alpha validation log-loss (float64 array), or ``None``
        without validation data.  ``"n_iter"``: per-alpha FISTA iteration counts.
        With ``return_path``, also ``"coef"`` (one row per alpha) and
        ``"intercept"``.  ``None`` for non-uniform ``sample_weight``.
    """
    if not _is_uniform_weight(sample_weight):
        warnings.warn(
            "_logistic_sparse_cv_path: non-uniform sample_weight not supported, "
            "falling back to general CV path.",
            RuntimeWarning,
            stacklevel=2,
        )
        return None

    backend = _backend_name_for_cv_device(device)
    Xb = _to_backend_float64(X_train, backend)
    yb = _to_backend_float64(y_train, backend).reshape(-1)
    alphas = np.asarray(alpha_sorted, dtype=np.float64).ravel()
    n_samples, n_features = Xb.shape

    # NOTE: this local _get_xp takes a backend *name* and shadows the
    # module-level _get_xp (an alias of _array_ops._xp, which takes an array).
    from statgpu.backends._utils import _get_xp, xp_ones
    xp = _get_xp(backend)
    ones = xp_ones((n_samples, 1), dtype=Xb.dtype, xp=xp, ref_arr=Xb)
    X_aug = xp.concatenate([Xb, ones], axis=1)
    y_mean = _to_float_scalar(xp.mean(yb))
    coef = _zeros(n_features, backend, ref_tensor=Xb)
    # Initial intercept: log-odds of the (clipped) training mean.
    _int_val = np.log(np.clip(y_mean, 1e-3, 1.0 - 1e-3) / (1.0 - np.clip(y_mean, 1e-3, 1.0 - 1e-3)))
    from statgpu.backends._array_ops import _scalar_tensor
    intercept = _scalar_tensor(_int_val, Xb)

    # Logistic loss curvature is bounded by 1/4, so lambda_max / (4n) bounds
    # the gradient's Lipschitz constant over [coef, intercept].
    eig_max = _max_eigval_power(X_aug.T @ X_aug)
    L_loss = max(eig_max / (4.0 * max(int(n_samples), 1)), 1e-12)
    step = 1.0 / L_loss
    conv_interval = _CONV_INTERVAL_CV_NUMPY if backend == "numpy" else _CONV_INTERVAL_CV_FOLD
    penalty_name = str(penalty_name).lower()
    is_enet = penalty_name in ("elasticnet", "en")

    if X_val is not None and y_val is not None:
        Xv = _to_backend_float64(X_val, backend)
        yv = _to_backend_float64(y_val, backend).reshape(-1)
        swv = _to_backend_float64(val_sample_weight, backend).reshape(-1) if val_sample_weight is not None else None
    else:
        Xv = yv = swv = None

    scores = []
    # Torch: per-alpha coef clones, scored in one batched GEMM after the path.
    # NumPy/CuPy: per-alpha scalar validation losses, scored inline.
    score_coef_path = []
    score_intercept_path = []
    coef_path = []
    intercept_path = []
    iters = []

    for alpha in alphas:
        # NOTE(review): objects without `.clone` (numpy/cupy values) go through
        # float(); on cupy that presumably forces a device sync every
        # iteration -- confirm whether _copy_arr could be used instead.
        y_coef = _copy_arr(coef)
        y_intercept = _copy_arr(intercept) if hasattr(intercept, 'clone') else float(intercept)
        t_k = 1.0
        last_iter = 0
        for iteration in range(int(max_iter)):
            coef_old = _copy_arr(coef)
            intercept_old = _copy_arr(intercept) if hasattr(intercept, 'clone') else float(intercept)

            # Gradient of the mean logistic loss at the extrapolated point.
            eta = Xb @ y_coef + y_intercept
            prob = _sigmoid(eta)
            resid = prob - yb
            grad_coef = Xb.T @ resid / n_samples
            grad_intercept = xp.mean(resid)

            w = y_coef - step * grad_coef
            if is_enet:
                thresh = float(alpha) * float(l1_ratio) * step
                denom = 1.0 + float(alpha) * (1.0 - float(l1_ratio)) * step
            else:
                thresh = float(alpha) * step
                denom = 1.0

            # Proximal step on coefficients only; the intercept is unpenalized.
            coef = _soft_threshold(w, thresh) / denom
            intercept = y_intercept - step * grad_intercept

            beta, t_k = _nesterov_momentum(t_k, beta_cap=0.5)
            y_coef = coef + beta * (coef - coef_old)
            y_intercept = intercept + beta * (intercept - intercept_old)
            last_iter = iteration + 1

            # Every iteration for the first 20, then every `conv_interval`
            # iterations, to limit host syncs.
            if iteration < 20 or iteration % conv_interval == 0:
                delta = xp.sum(xp.abs(coef - coef_old)) + xp.abs(intercept - intercept_old)
                converged = _to_float_scalar(delta) < tol
                if converged:
                    break

        if Xv is not None:
            if backend == "torch":
                score_coef_path.append(coef.clone())
                score_intercept_path.append(intercept.clone())
            else:
                eta_v = Xv @ coef + intercept
                per_sample = -yv * eta_v + _softplus(eta_v)
                if swv is not None:
                    sw_sum = xp.sum(swv)
                    val_loss = xp.sum(swv * per_sample) / sw_sum if float(sw_sum) > 0 else xp.mean(per_sample)
                else:
                    val_loss = xp.mean(per_sample)
                score_coef_path.append(val_loss)
        if return_path:
            coef_path.append(np.asarray(_to_numpy(coef), dtype=np.float64).copy())
            intercept_path.append(_scalar_to_float(intercept))
        iters.append(last_iter)

    # Torch benefits from one alpha-path GEMM for validation. For NumPy/CuPy
    # at these small alpha-grid widths, per-alpha GEMV is consistently steadier.
    if score_coef_path and Xv is not None:
        if backend == "torch":
            import torch
            coef_mat = torch.stack(score_coef_path, dim=1)
            intercept_vec = torch.stack(score_intercept_path).reshape(1, -1)
            eta_v = Xv @ coef_mat + intercept_vec
            per_sample = -yv.reshape(-1, 1) * eta_v + _softplus(eta_v)
            if swv is not None:
                sw_sum = swv.sum()
                if sw_sum > 0:
                    scores_tensor = (swv.reshape(-1, 1) * per_sample).sum(dim=0) / sw_sum
                else:
                    scores_tensor = per_sample.mean(dim=0)
            else:
                scores_tensor = per_sample.mean(dim=0)
            scores = _to_numpy(scores_tensor).tolist()
        else:
            # cupy/numpy path stored one scalar val_loss per alpha (backend
            # scalars); convert them to Python floats here.
            scores = [float(s) for s in score_coef_path]

    out = {
        "scores": np.asarray(scores, dtype=np.float64) if scores else None,
        "n_iter": np.asarray(iters, dtype=np.int64),
    }
    # NOTE(review): with an empty alpha_sorted, np.vstack([]) below raises
    # ValueError -- confirm callers never pass an empty grid.
    if return_path:
        out["coef"] = np.vstack(coef_path).astype(np.float64, copy=False)
        out["intercept"] = np.asarray(intercept_path, dtype=np.float64)
    return out
|
|
1040
|
+
|
|
1041
|
+
|
|
1042
|
+
# (Old per-loss fold-batched functions removed — replaced by _glm_sparse_cv_folds)
|
|
1043
|
+
|
|
1044
|
+
|
|
1045
|
+
def _squared_error_sparse_cv_path(
    X_train,
    y_train,
    alpha_sorted,
    penalty_name,
    l1_ratio,
    max_iter,
    tol,
    device,
    X_val=None,
    y_val=None,
    sample_weight=None,
    val_sample_weight=None,
    return_path=True,
):
    """Fit a squared-error sparse alpha path with centered data.

    This is used by CV for l1/elasticnet penalties. It solves all alphas in one
    fold using a single Gram matrix and warm-started FISTA path.

    The data are centered with the training means.  The intercept therefore
    never enters the FISTA iterations; it is recovered afterwards as
    ``y_mean - X_mean @ coef``.

    Parameters
    ----------
    X_train, y_train : array-like
        Training design matrix and response for one fold.
    alpha_sorted : array-like
        Regularization strengths, visited in the given order (each alpha is
        warm-started from the previous solution on the per-alpha path).
    penalty_name : str
        ``"elasticnet"``/``"en"`` selects the elastic-net proximal step; any
        other name uses a pure L1 soft-threshold.
    l1_ratio : float
        Elastic-net mixing parameter (elastic net only).
    max_iter, tol : int, float
        FISTA iteration cap and L1 coefficient-change convergence threshold.
    device : str
        CV device, mapped to a backend by ``_backend_name_for_cv_device``.
    X_val, y_val : array-like, optional
        Validation data; when both are given, per-alpha validation MSE is scored.
    sample_weight : array-like, optional
        Only uniform weights are supported; otherwise a ``RuntimeWarning`` is
        emitted and ``None`` is returned so the caller can fall back.
    val_sample_weight : array-like, optional
        Per-sample weights for validation scoring (weighted MSE).  If they do
        not sum to a positive value, the unweighted MSE is used instead (the
        same rule as ``_logistic_sparse_cv_path``) rather than returning NaN.
    return_path : bool
        Also return per-alpha coefficients and intercepts.

    Returns
    -------
    dict or None
        Keys:

        - ``"scores"``: per-alpha validation MSE, or ``None`` without
          validation data.
        - ``"n_iter"``: per-alpha FISTA iteration counts.
        - ``"coef"`` (one row per alpha) and ``"intercept"``: only when
          ``return_path``.

        ``None`` is returned instead for non-uniform ``sample_weight``.

    Notes
    -----
    On torch/cupy with validation data and ``return_path=False``, all alphas
    are solved jointly as the columns of one matrix.  That solve starts from
    zero (no warm starts), and every alpha reports the shared iteration count.
    """
    if not _is_uniform_weight(sample_weight):
        warnings.warn(
            "_squared_error_sparse_cv_path: non-uniform sample_weight not supported, "
            "falling back to general CV path.",
            RuntimeWarning,
            stacklevel=2,
        )
        return None

    backend = _backend_name_for_cv_device(device)
    Xb = _to_backend_float64(X_train, backend)
    yb = _to_backend_float64(y_train, backend).reshape(-1)
    alphas = np.asarray(alpha_sorted, dtype=np.float64).ravel()
    n_samples, n_features = Xb.shape
    penalty_name = str(penalty_name).lower()
    is_enet = penalty_name in ("elasticnet", "en")

    from statgpu.backends._utils import _get_xp
    xp = _get_xp(backend)
    X_mean = xp.mean(Xb, axis=0)
    y_mean = xp.mean(yb)
    Xc = Xb - X_mean
    yc = yb - y_mean
    XtX = Xc.T @ Xc
    Xty = Xc.T @ yc
    coef = _zeros(n_features, backend, ref_tensor=Xb)

    eig_max = _max_eigval_power(XtX)
    L = max(eig_max / max(int(n_samples), 1), 1e-12)
    step = 1.0 / L
    conv_interval = _CONV_INTERVAL_CV_NUMPY if backend == "numpy" else _CONV_INTERVAL_CV_PATH

    if X_val is not None and y_val is not None:
        Xv = _to_backend_float64(X_val, backend)
        yv = _to_backend_float64(y_val, backend).reshape(-1)
        Xv_centered = Xv - X_mean
        swv = _to_backend_float64(val_sample_weight, backend).reshape(-1) if val_sample_weight is not None else None
    else:
        Xv = yv = Xv_centered = swv = None

    # The validation-weight total is alpha-independent: compute it once.  A
    # non-positive (or NaN) total would make the weighted MSE 0/0, so fall
    # back to the unweighted mean, as _logistic_sparse_cv_path does.
    swv_total = None
    if swv is not None:
        swv_total = xp.sum(swv)
        if not _to_float_scalar(swv_total) > 0:
            swv = None

    if backend in ("torch", "cupy") and not return_path and Xv_centered is not None:
        # Batched path: one column of coef_mat per alpha.
        n_alpha = int(alphas.size)
        from statgpu.backends._utils import xp_asarray
        alpha_vec = xp_asarray(alphas, dtype=Xb.dtype, xp=xp, ref_arr=Xb).reshape(1, -1)
        coef_mat = _xp_zeros((n_features, n_alpha), Xb.dtype, Xb)
        y_mat = _copy_arr(coef_mat)

        # Proximal parameters are iteration-invariant.
        if is_enet:
            thresh = alpha_vec * float(l1_ratio) * step
            denom = 1.0 + alpha_vec * (1.0 - float(l1_ratio)) * step
        else:
            thresh = alpha_vec * step
            denom = 1.0

        t_k = 1.0
        last_iter = 0
        x_ty = Xty.reshape(-1, 1)
        for iteration in range(int(max_iter)):
            coef_old = _copy_arr(coef_mat)
            grad = (XtX @ y_mat - x_ty) / n_samples
            w = y_mat - step * grad

            coef_mat = _soft_threshold(w, thresh) / denom

            beta, t_k = _nesterov_momentum(t_k, beta_cap=0.5)
            y_mat = coef_mat + beta * (coef_mat - coef_old)
            last_iter = iteration + 1

            if iteration < 20 or iteration % conv_interval == 0:
                delta = xp.sum(xp.abs(coef_mat - coef_old), axis=0)
                if _to_float_scalar(xp.all(delta < tol)):
                    break

        pred = Xv_centered @ coef_mat + y_mean
        sq_err = (yv.reshape(-1, 1) - pred) ** 2
        if swv is not None:
            sw_col = swv.reshape(-1, 1)
            if backend == "torch":
                scores_dev = (sw_col * sq_err).sum(dim=0) / swv_total
            else:
                scores_dev = (sw_col * sq_err).sum(axis=0) / swv_total
        else:
            if backend == "torch":
                scores_dev = sq_err.mean(dim=0)
            else:
                scores_dev = sq_err.mean(axis=0)
        return {
            "scores": np.asarray(_to_numpy(scores_dev), dtype=np.float64),
            "n_iter": np.full(n_alpha, int(last_iter), dtype=np.int64),
        }

    scores = []
    scores_dev = []  # accumulate on device, sync once at end
    coef_path = []
    intercept_path = []
    iters = []

    for alpha in alphas:
        if is_enet:
            thresh = float(alpha) * float(l1_ratio) * step
            denom = 1.0 + float(alpha) * (1.0 - float(l1_ratio)) * step
        else:
            thresh = float(alpha) * step
            denom = 1.0

        y_k = _copy_arr(coef)  # warm start from the previous alpha
        t_k = 1.0
        last_iter = 0
        for iteration in range(int(max_iter)):
            coef_old = _copy_arr(coef)
            grad = (XtX @ y_k - Xty) / n_samples
            w = y_k - step * grad

            coef = _soft_threshold(w, thresh) / denom

            beta, t_k = _nesterov_momentum(t_k, beta_cap=0.5)
            y_k = coef + beta * (coef - coef_old)
            last_iter = iteration + 1

            # Wide GPU problems skip the dense early checks to limit host syncs.
            if backend == "numpy" or int(n_features) <= 128:
                check_convergence = iteration < 20 or iteration % conv_interval == 0
            else:
                check_convergence = iteration % conv_interval == 0
            if check_convergence:
                delta = xp.sum(xp.abs(coef - coef_old))
                if _to_float_scalar(delta) < tol:
                    break

        if Xv_centered is not None:
            pred = Xv_centered @ coef + y_mean
            sq_err = (yv - pred) ** 2
            if swv is not None:
                mse = xp.sum(swv * sq_err) / swv_total
            else:
                mse = xp.mean(sq_err)
            scores_dev.append(mse)  # keep on device
        if return_path:
            intercept = y_mean - X_mean @ coef
            coef_path.append(np.asarray(_to_numpy(coef), dtype=np.float64).copy())
            intercept_path.append(_scalar_to_float(intercept))
        iters.append(last_iter)

    # Batch sync validation scores from device.
    if scores_dev:
        if backend == "torch":
            import torch
            scores_tensor = torch.stack(scores_dev)
            scores = _to_numpy(scores_tensor).tolist()
        elif backend == "cupy":
            import cupy as cp
            scores_arr = cp.stack(scores_dev)
            scores = _to_numpy(scores_arr).tolist()
        else:
            scores = [float(s) for s in scores_dev]

    out = {
        "scores": np.asarray(scores, dtype=np.float64) if scores else None,
        "n_iter": np.asarray(iters, dtype=np.int64),
    }
    if return_path:
        out["coef"] = np.vstack(coef_path).astype(np.float64, copy=False)
        out["intercept"] = np.asarray(intercept_path, dtype=np.float64)
    return out
|
|
1226
|
+
|
|
1227
|
+
|
|
1228
|
+
# Intercept clipping bound: exp(15) ≈ 3.3M, prevents overflow in link
|
|
1229
|
+
# functions while allowing a wide range of intercept values.
|
|
1230
|
+
from statgpu.cross_validation._base import INTERCEPT_CLIP_BOUND as _INTERCEPT_CLIP_BOUND
|
|
1231
|
+
|
|
1232
|
+
|
|
1233
|
+
class _FeatureOnlySparsePenalty:
|
|
1234
|
+
"""Wrap a sparse penalty so the final intercept coefficient is unpenalized."""
|
|
1235
|
+
|
|
1236
|
+
def __init__(self, base_penalty, n_features, backend):
|
|
1237
|
+
self.base_penalty = base_penalty
|
|
1238
|
+
self.n_features = int(n_features)
|
|
1239
|
+
self.backend = backend
|
|
1240
|
+
|
|
1241
|
+
@property
|
|
1242
|
+
def name(self):
|
|
1243
|
+
return getattr(self.base_penalty, "name", "")
|
|
1244
|
+
|
|
1245
|
+
@property
|
|
1246
|
+
def alpha(self):
|
|
1247
|
+
return float(getattr(self.base_penalty, "alpha", 0.0))
|
|
1248
|
+
|
|
1249
|
+
@property
|
|
1250
|
+
def l1_ratio(self):
|
|
1251
|
+
return float(getattr(self.base_penalty, "l1_ratio", 1.0))
|
|
1252
|
+
|
|
1253
|
+
def value(self, coef):
|
|
1254
|
+
return self.base_penalty.value(coef[: self.n_features])
|
|
1255
|
+
|
|
1256
|
+
def proximal(self, w, step, backend=None):
|
|
1257
|
+
from statgpu.backends._array_ops import _xp, _clip, _xp_zeros
|
|
1258
|
+
backend = backend or self.backend
|
|
1259
|
+
xp = _xp(w)
|
|
1260
|
+
w_feat = w[: self.n_features]
|
|
1261
|
+
result_feat = self.base_penalty.proximal(w_feat, step, backend=backend)
|
|
1262
|
+
result = _xp_zeros(w.shape, w.dtype, w)
|
|
1263
|
+
result[: self.n_features] = result_feat
|
|
1264
|
+
result[self.n_features] = _clip(w[self.n_features], -_INTERCEPT_CLIP_BOUND, _INTERCEPT_CLIP_BOUND)
|
|
1265
|
+
return result
|
|
1266
|
+
|
|
1267
|
+
|
|
1268
|
+
def _glm_sparse_cv_path(
    loss_name,
    X_train,
    y_train,
    alpha_sorted,
    penalty_name,
    l1_ratio,
    max_iter,
    tol,
    device,
    X_val=None,
    y_val=None,
    sample_weight=None,
    val_sample_weight=None,
    return_path=False,
    solver_name="fista",
    cv_mode=True,
    loss_kwargs=None,
):
    """Warm-started sparse GLM alpha path for CV.

    The helper is intentionally private: it reuses the production loss,
    penalty, and FISTA solver while avoiding estimator reconstruction and
    repeated host/device conversions inside a fold.

    The training matrix is augmented with a trailing column of ones, so the
    solver's coefficient vector is ``[features..., intercept]``. The intercept
    is excluded from the penalty through ``_FeatureOnlySparsePenalty``. The
    alphas are solved in the given order, and each solve is warm-started from
    the previous solution.

    When ``val_sample_weight`` is provided, validation loss is computed as
    a weighted mean instead of a simple mean.

    Returns
    -------
    dict or None
        ``None`` when this fast path does not apply: the loss is not in
        ``_LOSS_RESIDUAL_FNS``, the penalty is not l1/elasticnet, or
        ``sample_weight`` is non-uniform (a RuntimeWarning is emitted in the
        last case).

        Otherwise a dict with these keys:

        - ``"scores"``: validation loss per alpha, or ``None`` without
          validation data.
        - ``"n_iter"``: solver iterations per alpha.
        - ``"coef"``: one row per alpha; present only if ``return_path``.
        - ``"intercept"``: one value per alpha; present only if
          ``return_path``.
    """
    loss_name = str(loss_name).lower()
    penalty_name = str(penalty_name).lower()
    # Allow any loss registered in the formula registry
    if loss_name not in _LOSS_RESIDUAL_FNS:
        return None
    if penalty_name not in ("l1", "elasticnet", "en"):
        return None
    if not _is_uniform_weight(sample_weight):
        warnings.warn(
            "_glm_sparse_cv_path: non-uniform sample_weight not supported, "
            "falling back to general CV path.",
            RuntimeWarning,
            stacklevel=2,
        )
        return None

    from statgpu.solvers import fista_solver, fista_bb_solver
    from statgpu.linear_model.penalized._fit_mixin import _resolve_loss_name
    from statgpu.penalties import get_penalty
    from statgpu.backends._utils import _get_xp
    from statgpu.backends._utils import xp_ones as _xp_ones_fn

    backend = _backend_name_for_cv_device(device)
    xp = _get_xp(backend)
    Xb = _to_backend_float64(X_train, backend)
    yb = _to_backend_float64(y_train, backend).reshape(-1)
    alphas = np.asarray(alpha_sorted, dtype=np.float64).ravel()
    n_samples, n_features = Xb.shape

    # Trailing column of ones carries the (unpenalized) intercept.
    _ones = _xp_ones_fn((n_samples, 1), dtype=Xb.dtype, xp=xp, ref_arr=Xb)
    X_work = xp.concatenate([Xb, _ones], axis=1)

    # Validation data, augmented the same way.
    Xv = X_val_work = yv = swv = None
    if X_val is not None and y_val is not None:
        Xv = _to_backend_float64(X_val, backend)
        yv = _to_backend_float64(y_val, backend).reshape(-1)
        _ones_v = _xp_ones_fn((Xv.shape[0], 1), dtype=Xv.dtype, xp=xp, ref_arr=Xv)
        X_val_work = xp.concatenate([Xv, _ones_v], axis=1)
        if val_sample_weight is not None:
            swv = _to_backend_float64(val_sample_weight, backend).reshape(-1)

    # Non-GPU backend with validation weights: the host copies of the
    # validation data do not change across alphas, so convert them once.
    val_np = None
    if swv is not None and backend not in ("torch", "cupy"):
        val_np = (
            np.asarray(_to_numpy(Xv), dtype=np.float64),
            np.asarray(_to_numpy(yv), dtype=np.float64).ravel(),
            np.asarray(_to_numpy(swv), dtype=np.float64).ravel(),
        )

    sw_fit = (
        _to_backend_float64(sample_weight, backend)
        if sample_weight is not None
        else None
    )
    loss_fn = _resolve_loss_name(loss_name, loss_kwargs=loss_kwargs)
    if penalty_name in ("elasticnet", "en"):
        base_penalty = get_penalty("elasticnet", alpha=float(alphas[0]), l1_ratio=float(l1_ratio))
    else:
        base_penalty = get_penalty("l1", alpha=float(alphas[0]))
    # The wrapper shares base_penalty, so updating base_penalty.alpha below
    # changes the penalty seen by the solver.
    penalty = _FeatureOnlySparsePenalty(base_penalty, n_features, backend)

    # Optional precomputed Lipschitz constant (evaluated at zero coefficients).
    lipschitz_L = None
    if not getattr(loss_fn, "_lipschitz_at_init", False):
        try:
            zero_lip = _zeros(n_features + 1, backend, ref_tensor=X_work)
            lipschitz_L = float(_to_numpy(loss_fn.lipschitz(X_work, zero_lip, y=yb)))
            if not np.isfinite(lipschitz_L) or lipschitz_L <= 0.0:
                lipschitz_L = None
        except Exception:
            # Best effort: on failure the solver is called without lipschitz_L.
            lipschitz_L = None

    if backend == "torch":
        import torch
        y_mean = max(float(torch.mean(yb).item()), 1e-3)
    elif backend == "cupy":
        import cupy as cp
        y_mean = max(float(cp.mean(yb)), 1e-3)
    else:
        y_mean = max(float(np.mean(yb)), 1e-3)
    # Use the correct link-function inverse for intercept initialization:
    #   logistic -> logit link: log(y_mean / (1 - y_mean))
    #   poisson/gamma/tweedie/nb -> log link: log(y_mean)
    if loss_name == "logistic":
        y_mean_clipped = np.clip(y_mean, 1e-7, 1.0 - 1e-7)
        init_intercept = np.log(y_mean_clipped / (1.0 - y_mean_clipped))
    else:
        init_intercept = np.log(y_mean)
    init = _zeros(n_features + 1, backend, ref_tensor=X_work)
    init[-1] = init_intercept

    solver_name = str(solver_name).lower()
    solver_fn = fista_bb_solver if solver_name == "fista_bb" else fista_solver

    scores = []
    # GPU backends: snapshots of the fitted params (scored in one batch later).
    # Other backends: validation losses (Python floats) computed per alpha.
    score_params_path = []
    coef_path = []
    intercept_path = []
    iters = []
    for alpha in alphas:
        base_penalty.alpha = float(alpha)
        solver_kwargs = {
            "max_iter": int(max_iter),
            "tol": tol,
            "init_coef": init,
            "sample_weight": sw_fit,
            # Both selectable solvers (fista_solver, fista_bb_solver) take cv_mode.
            "cv_mode": bool(cv_mode),
        }
        if lipschitz_L is not None:
            solver_kwargs["lipschitz_L"] = lipschitz_L
        params, n_iter = solver_fn(
            loss_fn,
            penalty,
            X_work,
            yb,
            **solver_kwargs,
        )
        init = params  # warm start for the next alpha
        if X_val_work is not None:
            if backend == "torch":
                # Snapshot, since params is passed back in as the next init_coef.
                score_params_path.append(params.clone())
            elif backend == "cupy":
                score_params_path.append(params.copy())
            elif val_np is not None:
                # Weighted validation loss on host arrays
                Xv_np, yv_np, sw_np = val_np
                params_np = np.asarray(_to_numpy(params), dtype=np.float64).ravel()
                val = _evaluate_loss_numpy(
                    loss_name,
                    loss_fn,
                    Xv_np,
                    yv_np,
                    params_np[:n_features],
                    float(params_np[n_features]),
                    True,
                    sample_weight=sw_np,
                )
                score_params_path.append(val)
            else:
                score_params_path.append(float(loss_fn.value(X_val_work, yv, params)))
        if return_path:
            params_np = np.asarray(_to_numpy(params), dtype=np.float64).ravel()
            coef_path.append(params_np[:n_features].copy())
            intercept_path.append(float(params_np[n_features]))
        iters.append(int(n_iter))

    # GPU backends: compute per-sample validation loss via registry,
    # then aggregate across samples (weighted or unweighted).
    # The registry functions work with any shape (1D or 2D batched eta).
    if score_params_path:
        _loss_params = {}
        if loss_name == "negative_binomial":
            _loss_params["alpha"] = float(getattr(loss_fn, "alpha", _NB_ALPHA_DEFAULT))
        elif loss_name == "tweedie":
            _loss_params["power"] = float(getattr(loss_fn, "power", _TWEEDIE_POWER_DEFAULT))

        if backend in ("torch", "cupy"):
            params_mat = xp.stack(score_params_path, axis=1)
            eta = X_val_work @ params_mat  # (n_val, n_alphas)
            yy = yv.reshape(-1, 1)
            per_sample = _LOSS_VALLOSS_FNS[loss_name](eta, yy, **_loss_params)
            if swv is not None:
                sw_col = swv.reshape(-1, 1)
                sw_sum = _to_float_scalar(xp.sum(swv))
                if sw_sum > 0:
                    scores_arr = xp.sum(sw_col * per_sample, axis=0) / sw_sum
                else:
                    scores_arr = xp.mean(per_sample, axis=0)
            else:
                scores_arr = xp.mean(per_sample, axis=0)
            scores = _to_numpy(scores_arr).tolist()
        else:
            scores = [_scalar_to_float(s) for s in score_params_path]

    out = {
        "scores": np.asarray(scores, dtype=np.float64) if scores else None,
        "n_iter": np.asarray(iters, dtype=np.int64),
    }
    if return_path:
        out["coef"] = np.vstack(coef_path).astype(np.float64, copy=False)
        out["intercept"] = np.asarray(intercept_path, dtype=np.float64)
    return out
|
|
1472
|
+
|
|
1473
|
+
|
|
1474
|
+
def _scad_mcp_cv_path(
    loss_name,
    X_train,
    y_train,
    alpha_sorted,
    penalty_name,
    l1_ratio,
    max_iter,
    tol,
    device,
    X_val=None,
    y_val=None,
    sample_weight=None,
    val_sample_weight=None,
    return_path=False,
    max_lla_per_step=3,
    lla_tol=1e-4,
    loss_kwargs=None,
):
    """Warm-started SCAD/MCP alpha path for CV.

    For each alpha: compute LLA weights from current coef, run FISTA with
    AdaptiveL1Penalty(weights=lla_w), warm-start from previous alpha.
    Avoids per-alpha model.fit() overhead.

    Design
    ------
    Intercept handling:
        The training matrix gets a trailing column of ones, so the
        coefficient vector is ``[features..., intercept]``. The intercept's
        LLA weight is 0, so it is never shrunk.

    LLA re-weighting:
        At most ``max_lla_per_step`` re-weightings run per alpha. They stop
        early once the L1 change of the coefficients falls below
        ``lla_tol``.

    Inner FISTA solve:
        Each solve runs at most ``min(max_iter, _FISTA_MAX_ITER_CV)``
        iterations and tests ``tol`` every 10 iterations.

    Losses:
        ``squared_error`` is solved on centered data with a precomputed Gram
        matrix. The intercept is then recovered as
        ``mean(y) - mean(X) @ coef``. Other losses use ``loss_fn.gradient``
        with a step size derived once from ``loss_fn.lipschitz`` at zero.

    Unused parameter:
        ``l1_ratio`` is accepted but not used by this path.

    Returns
    -------
    dict or None
        ``None`` when the penalty is not SCAD/MCP, or when ``sample_weight``
        is non-uniform (a RuntimeWarning is emitted in that case).

        Otherwise a dict with these keys:

        - ``"scores"``: validation loss per alpha, or ``None`` without
          validation data.
        - ``"n_iter"``: inner FISTA iterations of the last LLA step, per
          alpha.
        - ``"coef"`` and ``"intercept"``: present only if ``return_path``.
    """
    loss_name = str(loss_name).lower()
    penalty_name = str(penalty_name).lower()
    if penalty_name not in ("scad", "mcp"):
        return None
    if not _is_uniform_weight(sample_weight):
        warnings.warn(
            "_scad_mcp_cv_path: non-uniform sample_weight not supported, "
            "falling back to general CV path.",
            RuntimeWarning,
            stacklevel=2,
        )
        return None

    from statgpu.linear_model.penalized._fit_mixin import _resolve_loss_name
    from statgpu.penalties import SCADPenalty, MCPPenalty
    from statgpu.penalties._adaptive_l1 import AdaptiveL1Penalty
    from statgpu.backends._utils import _get_xp
    from statgpu.backends._utils import xp_ones as _xp_ones_fn

    backend = _backend_name_for_cv_device(device)
    xp = _get_xp(backend)
    Xb = _to_backend_float64(X_train, backend)
    yb = _to_backend_float64(y_train, backend).reshape(-1)
    alphas = np.asarray(alpha_sorted, dtype=np.float64).ravel()
    n_samples, n_features = Xb.shape

    # Augment X with intercept column
    _ones = _xp_ones_fn((n_samples, 1), dtype=Xb.dtype, xp=xp, ref_arr=Xb)
    X_work = xp.concatenate([Xb, _ones], axis=1)

    # Validation data, augmented the same way
    if X_val is not None and y_val is not None:
        Xv = _to_backend_float64(X_val, backend)
        yv = _to_backend_float64(y_val, backend).reshape(-1)
        _ones_v = _xp_ones_fn((Xv.shape[0], 1), dtype=Xv.dtype, xp=xp, ref_arr=Xv)
        X_val_work = xp.concatenate([Xv, _ones_v], axis=1)
    else:
        X_val_work = yv = None

    # Validation sample weights
    if val_sample_weight is not None and X_val_work is not None:
        swv = _to_backend_float64(val_sample_weight, backend).reshape(-1)
    else:
        swv = None

    loss_fn = _resolve_loss_name(loss_name, loss_kwargs=loss_kwargs)

    # SCAD/MCP penalty object; only used to produce LLA weights.
    if penalty_name == "scad":
        scad_penalty = SCADPenalty(alpha=float(alphas[0]))
    else:
        scad_penalty = MCPPenalty(alpha=float(alphas[0]))

    # Step size. Squared error uses a centered Gram matrix; GLM losses use
    # the loss's Lipschitz bound at zero, evaluated once for the whole path.
    _is_quadratic = (loss_name == "squared_error")
    X_mean = None
    y_mean = None
    XtX = Xty = None
    if _is_quadratic:
        X_mean = xp.mean(X_work[:, :n_features], axis=0)
        y_mean = xp.mean(yb)
        Xc = X_work[:, :n_features] - X_mean
        yc = yb - y_mean
        XtX = Xc.T @ Xc / n_samples
        Xty = Xc.T @ yc / n_samples
        eig_max = _max_eigval_power(XtX)
        lipschitz = max(eig_max * 1.01, 1.0)  # small safety factor for numerical stability
    else:
        _zero = _zeros(n_features + 1, backend, ref_tensor=Xb)
        lipschitz = float(_to_numpy(loss_fn.lipschitz(X_work, _zero, y=yb)))
        _safety = getattr(loss_fn, '_lipschitz_safety', 1.0)
        if _safety > 1.0:
            lipschitz *= _safety
        # Y-scaling for exp-link families
        _skip_ys = getattr(loss_fn, '_lipschitz_uses_y', False)
        if getattr(loss_fn, 'name', '') not in ('squared_error',) and not _skip_ys:
            _y_abs = np.abs(_to_numpy(yb))
            _y_scale = min(
                10.0,
                max(1.0, np.sqrt(float(np.mean(_y_abs)) * float(np.max(_y_abs)))),
            )
            if _y_scale > 1.0:
                lipschitz *= _y_scale
        lipschitz = max(lipschitz, 1.0)
    step = 1.0 / lipschitz

    # Pre-build loss-specific params (avoid dict construction in loop)
    _loss_params = {}
    if loss_name == "negative_binomial":
        _loss_params["alpha"] = float(getattr(loss_fn, "alpha", _NB_ALPHA_DEFAULT))
    elif loss_name == "tweedie":
        _loss_params["power"] = float(getattr(loss_fn, "power", _TWEEDIE_POWER_DEFAULT))

    # Loop invariants: validation weight total, inner-iteration cap for CV,
    # and the length-1 zero appended for the intercept (concatenate copies,
    # so reusing it is safe).
    sw_sum = _to_float_scalar(xp.sum(swv)) if swv is not None else None
    _inner_max_iter = min(int(max_iter), _FISTA_MAX_ITER_CV)
    _zero_tail = _zeros(1, backend, ref_tensor=Xb)

    # Coefficients start at zero and are warm-started across alphas.
    coef = _zeros(n_features + 1, backend, ref_tensor=Xb)
    # Inner penalty reused across LLA iterations; only its weights change.
    inner_pen = AdaptiveL1Penalty(alpha=1.0)

    scores = []
    scores_dev = []
    coef_path = []
    intercept_path = []
    iters = []

    for alpha in alphas:
        scad_penalty.alpha = float(alpha)
        # Reported as 0 iterations when no inner step runs (also keeps
        # `iteration` defined when max_lla_per_step <= 0).
        iteration = -1

        # LLA outer loop
        for _ in range(max_lla_per_step):
            # LLA weights from current coef (features only, intercept gets 0)
            lla_w_feat = scad_penalty.lla_weights(coef[:n_features])
            inner_pen._weights = xp.concatenate([lla_w_feat, _zero_tail])

            coef_before_lla = _copy_arr(coef)
            iteration = -1

            # FISTA inner solve with warm-start
            y_k = _copy_arr(coef)
            t_k = 1.0
            for iteration in range(_inner_max_iter):
                coef_old = _copy_arr(coef)
                if _is_quadratic:
                    # Centered Gram-matrix gradient; intercept gradient is 0.
                    grad = xp.concatenate([XtX @ y_k[:n_features] - Xty, _zero_tail])
                else:
                    grad = loss_fn.gradient(X_work, yb, y_k)

                coef = inner_pen.proximal(y_k - step * grad, step, backend=backend)

                beta_mom, t_k = _nesterov_momentum(t_k, beta_cap=0.5)
                y_k = coef + beta_mom * (coef - coef_old)

                # Convergence check (device-side, every 10 iters for CV)
                if iteration % 10 == 0 and iteration > 0:
                    if _device_gt(tol, _abs_sum_dev(coef - coef_old)):
                        break

            # LLA convergence check
            if _device_gt(lla_tol, _abs_sum_dev(coef - coef_before_lla)):
                break

        # Host copy of the feature coefficients
        coef_feat = coef[:n_features]
        if backend == "torch":
            coef_np = coef_feat.detach().cpu().numpy()
        elif backend == "cupy":
            coef_np = coef_feat.get()
        else:
            coef_np = coef_feat.copy()

        if _is_quadratic and X_mean is not None:
            # The quadratic gradient never moves coef[n_features]; recover the
            # intercept from centering and store it so validation and the next
            # warm start use it.
            intercept = float(y_mean - float(_to_numpy(xp.dot(X_mean, coef_feat))))
            coef[n_features] = intercept
        elif backend == "torch":
            intercept = float(coef[n_features].item())
        elif backend == "cupy":
            intercept = float(coef[n_features].get())
        else:
            intercept = float(coef[n_features])

        # Validation loss (weighted if val_sample_weight provided)
        if X_val_work is not None:
            if swv is not None:
                eta_v = X_val_work @ coef
                if loss_name == "squared_error":
                    per_sample = (yv - eta_v) ** 2
                else:
                    per_sample = _LOSS_VALLOSS_FNS[loss_name](eta_v, yv, **_loss_params)
                val_loss = _to_float_scalar(xp.sum(swv * per_sample)) / max(sw_sum, 1e-15)
            else:
                val_loss = loss_fn.value(X_val_work, yv, coef)
            scores_dev.append(float(val_loss))

        if return_path:
            coef_path.append(coef_np)
            intercept_path.append(intercept)
        iters.append(iteration + 1)

    if scores_dev:
        scores = [float(s) for s in scores_dev]

    out = {
        "scores": np.asarray(scores, dtype=np.float64) if scores else None,
        "n_iter": np.asarray(iters, dtype=np.int64),
    }
    if return_path:
        out["coef"] = np.vstack(coef_path).astype(np.float64, copy=False)
        out["intercept"] = np.asarray(intercept_path, dtype=np.float64)
    return out
|
|
1750
|
+
|
|
1751
|
+
|
|
1752
|
+
# ---------------------------------------------------------------------------
|
|
1753
|
+
# Data-driven GPU device selection thresholds for CV
|
|
1754
|
+
# ---------------------------------------------------------------------------
|
|
1755
|
+
# Each entry: (loss, penalties, min_nx, min_features, reason, or_min_features)
|
|
1756
|
+
# - loss: loss name or None (matches any non-squared-error)
|
|
1757
|
+
# - penalties: tuple of penalty names that trigger GPU evaluation
|
|
1758
|
+
# - min_nx: minimum n_samples * n_features to consider GPU
|
|
1759
|
+
# - min_features: minimum n_features (0 = no feature threshold)
|
|
1760
|
+
# - reason: explanation string for the auto-selection decision
|
|
1761
|
+
# - or_min_features: if >0, also allow GPU when n_features >= or_min_features
|
|
1762
|
+
# AND n_samples*n_features >= 1_000_000 (OR with primary cond)
|
|
1763
|
+
# If the condition is met and torch.cuda is available → "torch", else → "cpu".
|
|
1764
|
+
_CV_DEVICE_THRESHOLDS = [
|
|
1765
|
+
("squared_error", ("l1", "elasticnet", "en"), 1_000_000, 256,
|
|
1766
|
+
"medium squared-error sparse CV benefits from batched torch alpha path", 0),
|
|
1767
|
+
(None, ("scad", "mcp"), 1_000_000, 0,
|
|
1768
|
+
"large GLM SCAD/MCP CV benefits from torch async FISTA", 0),
|
|
1769
|
+
("logistic", ("l1", "elasticnet", "en"), 1_000_000, 500,
|
|
1770
|
+
"high-dimensional logistic sparse CV benefits from torch", 0),
|
|
1771
|
+
("logistic", ("l1", "elasticnet", "en"), 500_000, 100,
|
|
1772
|
+
"medium logistic sparse CV benefits from fold-batched torch path", 0),
|
|
1773
|
+
("poisson", ("l1", "elasticnet", "en"), 1_000_000, 500,
|
|
1774
|
+
"high-dimensional poisson sparse CV benefits from torch", 0),
|
|
1775
|
+
("gamma", ("l1", "elasticnet", "en"), 2_000_000, 500,
|
|
1776
|
+
"large high-dimensional gamma sparse CV benefits from torch", 0),
|
|
1777
|
+
("inverse_gaussian", ("l1", "elasticnet", "en"), 2_000_000, 500,
|
|
1778
|
+
"large high-dimensional inverse-gaussian sparse CV benefits from torch", 0),
|
|
1779
|
+
("tweedie", ("l1", "elasticnet", "en"), 300_000, 0,
|
|
1780
|
+
"medium tweedie sparse CV is faster on torch", 0),
|
|
1781
|
+
]
|
|
1782
|
+
# Special: always-CPU losses (regardless of problem size)
|
|
1783
|
+
_CV_DEVICE_ALWAYS_CPU = {
|
|
1784
|
+
"negative_binomial": "negative-binomial CV is faster on CPU for current benchmarked sizes",
|
|
1785
|
+
}
|
|
1786
|
+
|
|
1787
|
+
|
|
1788
|
+
class PenalizedGLM_CV(CVEstimatorBase):
|
|
1789
|
+
"""Cross-validated penalized GLM supporting all loss + penalty combinations."""
|
|
1790
|
+
|
|
1791
|
+
def __init__(
    self,
    loss: str = 'squared_error',
    penalty: str = 'l2',
    alpha_grid=None,
    n_alphas: int = 100,
    l1_ratio: float = 0.5,
    cv: int = 5,
    cv_splits=None,
    random_state: Optional[int] = 0,
    device: Union[str, Device] = Device.AUTO,
    max_iter: int = 1000,
    tol: float = 1e-4,
    solver: str = 'auto',
    cv_strategy: str = "strict",
    acknowledge_approx: bool = False,
    refine_top_k: int = 3,
    loss_kwargs: Optional[dict] = None,
):
    """Store the CV configuration after validating the strategy options.

    ``cv``, ``random_state`` and ``device`` are forwarded to the base class.

    Parameters
    ----------
    alpha_grid : array-like, optional
        User-supplied alpha grid, stored as ``_alpha_grid_input``.
    n_alphas : int
        Grid length used when the alpha grid is generated automatically.
    cv_strategy : {"strict", "two_stage"}
        Compared case-insensitively and stored lower-cased.
    refine_top_k : int
        Stored as ``int``; must be at least 1.
    loss_kwargs : dict, optional
        Extra loss options, stored as ``_loss_kwargs`` (``{}`` when falsy).

    Raises
    ------
    ValueError
        If ``cv_strategy`` is not "strict"/"two_stage" or ``refine_top_k`` < 1.
    """
    super().__init__(cv=cv, random_state=random_state, device=device)
    self.cv_splits = cv_splits

    # Validate before storing anything else.
    cv_strategy = str(cv_strategy).lower()
    valid_strategies = ("strict", "two_stage")
    if cv_strategy not in valid_strategies:
        raise ValueError(
            "cv_strategy must be either 'strict' or 'two_stage', "
            f"got {cv_strategy!r}."
        )
    top_k = int(refine_top_k)
    if top_k < 1:
        raise ValueError("refine_top_k must be a positive integer")

    # Hyper-parameters (assigned in the same order as before).
    self.loss = loss
    self._loss_kwargs = {} if not loss_kwargs else loss_kwargs
    self.penalty = penalty
    self._alpha_grid_input = alpha_grid
    self.n_alphas = n_alphas
    self.l1_ratio = l1_ratio
    self.max_iter = max_iter
    self.tol = tol
    self.solver = solver
    self.cv_strategy = cv_strategy
    self.acknowledge_approx = bool(acknowledge_approx)
    self.refine_top_k = top_k

    # Result/diagnostic attributes start empty. cv_selected_device_ and
    # _cv_auto_reason_ are written by _effective_cv_device; the others are
    # presumably filled during fit.
    self.alpha_ = None
    self.alpha_grid_ = None
    self.cv_strategy_ = None
    self.cv_selected_device_ = None
    self._cv_auto_reason_ = None
|
|
1838
|
+
|
|
1839
|
+
def _solver_for_cv(self, cv_device=None, X=None):
|
|
1840
|
+
"""Return the strict internal solver used by the CV loop."""
|
|
1841
|
+
solver = str(self.solver).lower()
|
|
1842
|
+
if solver != "auto":
|
|
1843
|
+
return solver
|
|
1844
|
+
from statgpu.linear_model.penalized._fit_mixin import _preferred_penalized_glm_solver
|
|
1845
|
+
|
|
1846
|
+
return _preferred_penalized_glm_solver(
|
|
1847
|
+
self.loss,
|
|
1848
|
+
self.penalty,
|
|
1849
|
+
backend_name=_backend_name_for_cv_device(
|
|
1850
|
+
self.device if cv_device is None else cv_device
|
|
1851
|
+
),
|
|
1852
|
+
l1_ratio=self.l1_ratio,
|
|
1853
|
+
cv_mode=True,
|
|
1854
|
+
problem_size=None if X is None else int(X.shape[0]) * int(X.shape[1]),
|
|
1855
|
+
)
|
|
1856
|
+
|
|
1857
|
+
def _effective_cv_device(self, X, penalty_name, n_alphas):
    """Resolve device for CV-level work; explicit devices are untouched.

    Only ``device="auto"`` is resolved; any other ``self.device`` is returned
    as-is. The chosen device is recorded in ``self.cv_selected_device_`` and
    a human-readable reason in ``self._cv_auto_reason_``.

    Resolution order for "auto":

    1. ``n_samples * n_features < _SMALL_PROBLEM_THRESHOLD`` -> "cpu".
    2. Loss listed in ``_CV_DEVICE_ALWAYS_CPU`` with an l2/l1/elastic-net
       penalty -> "cpu".
    3. Rules in ``_CV_DEVICE_THRESHOLDS`` that match the loss and penalty are
       tried in order. The first rule whose size condition holds selects
       "torch" when CUDA is available. If that rule fires without CUDA, or
       no matching rule fires, the result is "cpu".
    4. Otherwise the result depends on the estimated work
       ``nx * cv * n_alphas``, multiplied by 20 for non-squared-error
       SCAD/MCP:
       - below ``_GPU_BREAK_EVEN_THRESHOLD`` -> "cpu";
       - else torch with CUDA;
       - else CuPy with at least one visible device;
       - else "cpu".

    Returns
    -------
    ``self.device`` unchanged, or one of "cpu", "torch", "cupy".
    """
    self.cv_selected_device_ = self.device
    self._cv_auto_reason_ = None
    if _device_to_name(self.device) != "auto":
        return self.device

    n_samples, n_features = X.shape
    n_features = int(n_features)
    penalty_name = str(penalty_name).lower()
    loss_name = str(self.loss).lower()
    nx = int(n_samples) * n_features

    def _choose(selected, reason):
        # Record the decision on the estimator and return it.
        self.cv_selected_device_ = selected
        self._cv_auto_reason_ = reason
        return selected

    # Small problems: always CPU
    if nx < _SMALL_PROBLEM_THRESHOLD:
        return _choose("cpu", "small CV problem is faster on CPU")

    # Always-CPU losses
    if loss_name in _CV_DEVICE_ALWAYS_CPU and penalty_name in ("l2", "l1", "elasticnet", "en"):
        return _choose("cpu", _CV_DEVICE_ALWAYS_CPU[loss_name])

    # Data-driven threshold lookup. Several rules can match the same
    # loss/penalty (e.g. the two logistic tiers), so every matching rule is
    # tried before falling back to CPU; otherwise later tiers are unreachable.
    first_reason = None
    for rule_loss, rule_penalties, min_nx, min_features, reason, or_min_feat in _CV_DEVICE_THRESHOLDS:
        loss_match = (rule_loss is None and loss_name != "squared_error") or rule_loss == loss_name
        if not (loss_match and penalty_name in rule_penalties):
            continue
        if first_reason is None:
            first_reason = reason
        # Primary condition: nx >= min_nx AND n_features >= min_features
        cond = nx >= min_nx and n_features >= min_features
        # OR condition: n_features >= or_min_feat AND nx >= 1_000_000
        if not cond and or_min_feat > 0:
            cond = n_features >= or_min_feat and nx >= 1_000_000
        if cond:
            if _torch_cuda_available():
                return _choose("torch", reason)
            # Without CUDA no later rule can select torch either.
            return _choose(
                "cpu", reason.replace("benefits from", "is faster on CPU below break-even for")
            )
    if first_reason is not None:
        return _choose(
            "cpu", first_reason.replace("benefits from", "is faster on CPU below break-even for")
        )

    # Fallback: large effective work -> GPU
    # NOTE(review): non-squared-error SCAD/MCP is always matched by a rule
    # above, so the factor of 20 appears unreachable here - confirm.
    continuation_factor = 20 if loss_name != "squared_error" and penalty_name in ("scad", "mcp") else 1
    effective_work = nx * int(self.cv) * int(n_alphas) * continuation_factor
    if effective_work < _GPU_BREAK_EVEN_THRESHOLD:
        return _choose("cpu", "CV effective work is below GPU break-even")

    # Prefer torch when CUDA is available, then CuPy with a visible device.
    if _torch_cuda_available():
        return _choose("torch", "GPU selected for large CV effective work")
    try:
        import cupy
        has_cupy_gpu = cupy.cuda.runtime.getDeviceCount() > 0
    except Exception:
        # Best-effort probe: missing CuPy or a missing/broken CUDA runtime
        # simply means CuPy is not usable here.
        has_cupy_gpu = False
    if has_cupy_gpu:
        return _choose("cupy", "GPU selected for large CV effective work")
    return _choose("cpu", "No GPU available, falling back to CPU")
|
|
1925
|
+
|
|
1926
|
+
def _generate_alpha_grid(self, X, y):
    """Auto-generate a descending, log-spaced alpha grid for CV.

    The largest alpha approximates the smallest penalty at which every
    coefficient is zero: the sup-norm of the loss gradient with respect to
    the coefficients, evaluated at the intercept-only (null) model. The grid
    then spans four decades below it with ``self.n_alphas`` points.

    Parameters
    ----------
    X : array-like
        Design matrix (any backend accepted by ``_to_numpy``).
    y : array-like
        Response; flattened to 1-D.

    Returns
    -------
    numpy.ndarray
        ``self.n_alphas`` values from ``alpha_max`` down to
        ``max(alpha_max * 1e-4, 1e-12)``.

    Warns
    -----
    RuntimeWarning
        When the estimate fails or is non-finite / non-positive, in which
        case ``alpha_max = 1.0`` is used.
    """
    from statgpu.linear_model.penalized._base import PenalizedGeneralizedLinearModel

    X_np = _to_numpy(X).astype(np.float64)
    y_np = _to_numpy(y).astype(np.float64).ravel()
    n = X_np.shape[0]
    loss_name = str(self.loss).lower()
    penalty_name = str(self.penalty).lower()

    # Losses whose intercept-only fit has constant mean mu_null = mean(y).
    # For these the null-model gradient is X'(y - mean(y)) / n. This is
    # exact for canonical links (squared_error, logistic, poisson). For the
    # other families it omits the family's variance/link factor.
    # NOTE(review): assumes the default link of each family; confirm if a
    # tighter alpha_max is needed for gamma / inverse_gaussian / tweedie /
    # negative_binomial.
    null_mean_losses = (
        "squared_error",
        "logistic",
        "poisson",
        "gamma",
        "inverse_gaussian",
        "tweedie",
        "negative_binomial",
    )

    if n == 0 or X_np.size == 0:
        # No samples or no features: nothing to differentiate. The guard
        # below replaces this with alpha_max = 1.0 and warns.
        alpha_max = 0.0
    elif loss_name in null_mean_losses:
        mu_null = float(np.mean(y_np))
        alpha_max = float(np.max(np.abs(X_np.T @ (y_np - mu_null)))) / n
    else:
        # Unrecognised loss: keep the original model-based estimate.
        # Fitting the full model (rather than the null model) makes this only
        # a rough scale for the grid.
        try:
            model = PenalizedGeneralizedLinearModel(
                loss=self.loss, penalty='l2', alpha=0.0,
                device='cpu', compute_inference=False, max_iter=5,
                loss_kwargs=getattr(self, '_loss_kwargs', None),
            )
            model.fit(X_np, y_np)
            grad = X_np.T @ (y_np - _to_numpy(model.predict(X_np))) / n
            alpha_max = float(np.max(np.abs(grad)))
        except Exception as e:
            warnings.warn(
                f"Alpha grid estimation failed ({e}), using alpha_max=1.0",
                RuntimeWarning,
                stacklevel=2,
            )
            alpha_max = 1.0

    # For elasticnet, the L1 component threshold is alpha*l1_ratio, so
    # alpha_max is scaled by 1/l1_ratio. "en" is the alias the rest of the
    # class accepts for elasticnet.
    if penalty_name in ("elasticnet", "en") and hasattr(self, 'l1_ratio'):
        _l1r = max(float(self.l1_ratio), 1e-10)
        alpha_max = alpha_max / _l1r

    # Also reject NaN/inf: `nan <= 0` is False and would yield a NaN grid.
    if not np.isfinite(alpha_max) or alpha_max <= 0:
        warnings.warn(
            f"Alpha grid estimation returned {alpha_max}, using alpha_max=1.0",
            RuntimeWarning,
            stacklevel=2,
        )
        alpha_max = 1.0

    grid = np.geomspace(alpha_max, max(alpha_max * 1e-4, 1e-12), self.n_alphas)
    return grid
|
|
1975
|
+
|
|
1976
|
+
def _solve_ridge_fold_batch(self, X_train, y_train, X_val, y_val, alphas):
    """Score every alpha of a Ridge path on one CV fold in a single call.

    All inputs are copied to host float64 arrays; the targets and alphas are
    flattened to 1-D. They are then handed to ``_ridge_eig_batch``, whose
    result is returned unchanged.
    """
    def as_host_f64(arr):
        # Backend array -> NumPy float64 copy.
        return _to_numpy(arr).astype(np.float64)

    return _ridge_eig_batch(
        as_host_f64(X_train),
        as_host_f64(y_train).ravel(),
        as_host_f64(X_val),
        as_host_f64(y_val).ravel(),
        as_host_f64(alphas).ravel(),
    )
|
|
1984
|
+
|
|
1985
|
+
def _evaluate_single(self, model, X_val, y_val, loss_fn=None, X_val_np=None, y_val_np=None, sample_weight=None):
    """Score one fitted model on a validation split and return the loss.

    Three strategies are tried in order:

    1. the shared numpy evaluator ``_evaluate_loss_numpy`` (the only one
       that receives ``sample_weight``);
    2. the loss object's own ``value`` on an explicit design matrix (with a
       leading column of ones when the model fits an intercept);
    3. unweighted mean squared error of ``model.predict``, with a
       RuntimeWarning, because MSE is the wrong metric for non-Gaussian
       losses.

    Parameters
    ----------
    loss_fn : optional, pre-resolved loss object; resolved from ``self.loss``
        when omitted (avoids repeated import/resolution).
    X_val_np, y_val_np : optional, pre-cached host float64 validation data
        (avoids device-to-host copies).
    sample_weight : optional, per-sample weights for the weighted validation
        loss.
    """
    from statgpu.linear_model.penalized._fit_mixin import _resolve_loss_name

    loss_fn = _resolve_loss_name(self.loss) if loss_fn is None else loss_fn
    X_host = _to_numpy(X_val).astype(np.float64) if X_val_np is None else X_val_np
    y_host = _to_numpy(y_val).astype(np.float64).ravel() if y_val_np is None else y_val_np
    n_rows = X_host.shape[0]

    try:
        return _evaluate_loss_numpy(
            self.loss,
            loss_fn,
            X_host,
            y_host,
            _to_numpy(model.coef_).ravel(),
            float(model.intercept_),
            model.fit_intercept,
            sample_weight=sample_weight,
        )
    except Exception:
        # Fallback: use loss_fn.value() for the correct loss, not raw MSE.
        try:
            if not model.fit_intercept:
                design = X_host
                params = _to_numpy(model.coef_).ravel()
            else:
                design = np.column_stack([np.ones(n_rows), X_host])
                params = np.concatenate([[float(model.intercept_)], _to_numpy(model.coef_).ravel()])
            return float(loss_fn.value(design, y_host, params))
        except Exception:
            predicted = _to_numpy(model.predict(X_host)).ravel()
            mse = float(np.mean((y_host - predicted) ** 2))
            warnings.warn(
                f"_evaluate_single: loss evaluation failed for '{self.loss}', "
                f"falling back to MSE. CV scores may be inaccurate for non-Gaussian losses.",
                RuntimeWarning,
                stacklevel=2,
            )
            return mse
|
|
2036
|
+
|
|
2037
|
+
@staticmethod
def _populate_refit_model(model, coef, intercept, X, device, n_iter=None):
    """Install path-solver results on a refit model and mark it fitted.

    Parameters
    ----------
    model : PenalizedGeneralizedLinearModel
        Freshly constructed (unfitted) model to populate in place.
    coef : array-like
        Coefficient vector from the path solver. It may live on any backend
        and is converted with ``_to_numpy``, as the rest of this class does
        for model outputs, because ``np.asarray`` rejects CuPy arrays and
        CUDA tensors.
    intercept : scalar
        Fitted intercept.
    X : array-like
        Training design matrix; only its shape is used.
    device : str or device
        CV/refit device, mapped to a backend name.
    n_iter : int, optional
        Iteration count; ``n_iter_`` is only set when given.

    Returns
    -------
    The same ``model`` object.
    """
    coef_np = np.asarray(_to_numpy(coef), dtype=np.float64)
    intercept_f = float(intercept)

    model.coef_ = coef_np
    model.intercept_ = intercept_f
    if n_iter is not None:
        model.n_iter_ = int(n_iter)
    # NOTE(review): the intercept is always prepended here, while _df_resid
    # below counts it only when fit_intercept is true. Confirm against the
    # base model's `_params` layout for fit_intercept=False.
    model._params = np.concatenate([[intercept_f], coef_np])

    n_obs = int(X.shape[0])
    model._nobs = n_obs
    n_params = int(X.shape[1]) + (1 if bool(getattr(model, 'fit_intercept', True)) else 0)
    model._df_resid = n_obs - n_params
    model._selected_backend_name = _backend_name_for_cv_device(device)
    model._fitted = True
    return model
|
|
2051
|
+
|
|
2052
|
+
def _refit_best(self, X, y, best_alpha, sample_weight=None):
    """Refit on the full data with the selected alpha.

    For squared_error + l2, uses the closed-form ``_ridge_eig_single`` solve
    to match the Ridge CV path exactly, avoiding precision mismatch between
    CV and refit solvers.

    For sparse penalties, a single-alpha run of the matching path solver is
    tried (logistic sparse, squared-error sparse, then the GLM sparse path
    when ``_uses_glm_sparse_path`` allows it). The first one returning a
    non-None path wins. Otherwise a fresh
    ``PenalizedGeneralizedLinearModel`` is fitted directly.

    Parameters
    ----------
    X, y : array-like
        Full training data (any backend).
    best_alpha : float
        Penalty strength selected by CV.
    sample_weight : array-like, optional
        Per-sample weights, forwarded to every refit route.

    Returns
    -------
    PenalizedGeneralizedLinearModel
        A fitted model.
    """
    from statgpu.linear_model.penalized._base import PenalizedGeneralizedLinearModel

    # Resolve refit device (used by Ridge and general paths). Under
    # device="auto", reuse the device the CV stage ran on. That is recorded
    # as `cv_selected_device_`; the previously-read `_cv_selected_device_` is
    # still honoured first for backward compatibility.
    refit_device = self.device
    if _device_to_name(self.device) == "auto":
        refit_device = (
            getattr(self, "_cv_selected_device_", None)
            or getattr(self, "cv_selected_device_", None)
            or self.device
        )

    # For Ridge: use eigendecomposition to match CV path exactly.
    # Supports weighted Ridge via weighted eigensolve (same O(p³) cost).
    if self.loss == 'squared_error' and self.penalty == 'l2':
        X_np = _to_numpy(X).astype(np.float64)
        y_np = _to_numpy(y).astype(np.float64).ravel()
        sw_np = _to_numpy(sample_weight).astype(np.float64).ravel() if sample_weight is not None else None
        coef, intercept = _ridge_eig_single(X_np, y_np, best_alpha, sample_weight=sw_np)
        model = PenalizedGeneralizedLinearModel(
            loss='squared_error', penalty='l2', alpha=best_alpha,
            device=refit_device, compute_inference=False,
            max_iter=self.max_iter, tol=self.tol,
            loss_kwargs=getattr(self, '_loss_kwargs', None),
        )
        return self._populate_refit_model(model, coef, intercept, X, refit_device)

    # NOTE: the Ridge case returned above, so this is False on every path
    # that reaches here; kept as the explicit inference switch.
    can_infer = (self.loss == 'squared_error' and self.penalty == 'l2')
    penalty_name = str(self.penalty).lower()
    alpha_arr = np.asarray([best_alpha], dtype=np.float64)

    # Try specialized refit paths (each returns a path dict or None)
    refit_paths = []
    if self.loss == "logistic" and penalty_name in ("l1", "elasticnet", "en"):
        refit_paths.append(lambda: _logistic_sparse_cv_path(
            X, y, alpha_arr, penalty_name, self.l1_ratio,
            _logistic_sparse_effective_max_iter(self.max_iter, refit_device, penalty_name, refit=True),
            self.tol, refit_device, sample_weight=sample_weight,
        ))
    if self.loss == "squared_error" and penalty_name in ("l1", "elasticnet", "en"):
        refit_paths.append(lambda: _squared_error_sparse_cv_path(
            X, y, alpha_arr, penalty_name, self.l1_ratio,
            self.max_iter, self.tol, refit_device, sample_weight=sample_weight,
        ))
    cv_solver = self._solver_for_cv(refit_device, X=X)
    if self._uses_glm_sparse_path(penalty_name, cv_solver):
        refit_paths.append(lambda: _glm_sparse_cv_path(
            self.loss, X, y, alpha_arr, penalty_name, self.l1_ratio,
            self.max_iter, self.tol, refit_device,
            return_path=True, solver_name=cv_solver, cv_mode=False,
            sample_weight=sample_weight,
        ))

    for get_path in refit_paths:
        path = get_path()
        if path is not None:
            model = PenalizedGeneralizedLinearModel(
                loss=self.loss, penalty=self.penalty, alpha=best_alpha,
                l1_ratio=self.l1_ratio, device=refit_device,
                compute_inference=False, max_iter=self.max_iter,
                tol=self.tol, solver=cv_solver,
                loss_kwargs=getattr(self, '_loss_kwargs', None),
            )
            # The path was run over a single alpha; [-1] is that solution.
            return self._populate_refit_model(
                model, path["coef"][-1], path["intercept"][-1],
                X, refit_device, n_iter=path["n_iter"][-1],
            )

    # General fallback: model.fit()
    model = PenalizedGeneralizedLinearModel(
        loss=self.loss, penalty=self.penalty, alpha=best_alpha,
        l1_ratio=self.l1_ratio, device=refit_device,
        compute_inference=can_infer, max_iter=self.max_iter,
        tol=self.tol, solver=cv_solver,
        loss_kwargs=getattr(self, '_loss_kwargs', None),
    )
    model.fit(X, y, sample_weight=sample_weight)
    return model
|
|
2131
|
+
|
|
2132
|
+
def _uses_glm_sparse_path(self, penalty_name, cv_solver):
|
|
2133
|
+
penalty_name = str(penalty_name).lower()
|
|
2134
|
+
cv_solver = str(cv_solver).lower()
|
|
2135
|
+
return (
|
|
2136
|
+
(
|
|
2137
|
+
(self.loss == "poisson" and penalty_name in ("l1", "elasticnet", "en"))
|
|
2138
|
+
or self.loss in ("gamma", "inverse_gaussian", "tweedie")
|
|
2139
|
+
or (self.loss == "negative_binomial" and cv_solver == "fista_bb")
|
|
2140
|
+
)
|
|
2141
|
+
and penalty_name in ("l1", "elasticnet", "en")
|
|
2142
|
+
and cv_solver in ("auto", "fista", "fista_bb")
|
|
2143
|
+
)
|
|
2144
|
+
|
|
2145
|
+
def _best_index_from_scores(self, mean_scores, alpha_grid, cv_solver):
    """Pick the index of the best alpha from fold-averaged CV scores.

    Delegates to ``_nanargmin_prefer_larger_alpha``. Two configurations pass
    explicit tie tolerances so near-equal scores resolve deterministically
    toward stronger regularization:

    * Poisson with an L1/elastic-net penalty (``rel_tol=5e-7``,
      ``abs_tol=1e-6``). Its CV curve can be nearly flat at the low-alpha
      end, and CPU/CuPy/Torch scores may differ at ~1e-7 from backend
      summation order.
    * Any other configuration routed through the GLM sparse path
      (``rel_tol=5e-6``, ``abs_tol=1e-7``).

    Every other case uses the helper's default tolerances.
    """
    penalty = str(self.penalty).lower()
    loss = str(self.loss).lower()

    if loss == "poisson" and penalty in ("l1", "elasticnet", "en"):
        tie_tolerances = {"rel_tol": 5e-7, "abs_tol": 1e-6}
    elif self._uses_glm_sparse_path(penalty, cv_solver):
        tie_tolerances = {"rel_tol": 5e-6, "abs_tol": 1e-7}
    else:
        tie_tolerances = {}
    return _nanargmin_prefer_larger_alpha(mean_scores, alpha_grid, **tie_tolerances)
|
|
2167
|
+
|
|
2168
|
+
def _compute_cv_scores(
    self,
    X,
    y,
    alpha_grid,
    cv_device,
    folds,
    sample_weight=None,
    max_iter=None,
    tol=None,
    strict=True,
):
    """Compute CV scores for exactly the supplied alpha grid.

    Returns an array of shape ``(len(folds), len(alpha_grid))``. Columns
    follow the order of ``alpha_grid`` as passed in, even though the path
    solvers run on a descending copy. Cells that no route fills stay NaN.

    Dispatch order:

    1. Ridge (squared_error + l2, no sample weights, device name not
       ``"cuda"``/``"torch"``): ``self._solve_ridge_fold_batch`` per fold.
    2. When ``strict`` is False, an L1/elastic-net penalty, a loss listed in
       ``_FOLD_BATCH_CONFIGS`` and a torch/cuda device: all folds in one call
       to ``_glm_sparse_cv_folds``. Falls through with a warning on failure.
    3. Per fold: the first specialised path solver (SCAD/MCP, logistic
       sparse, squared-error sparse, GLM sparse) whose condition holds and
       that returns scores. Otherwise ``self._cv_fold_general``.

    Parameters
    ----------
    X, y : array-like
        Full data; rows are sliced per fold with ``_slice_rows``.
    alpha_grid : array-like
        Penalty strengths to score (flattened to float64).
    cv_device : str or device
        Device the CV fits run on.
    folds : sequence of (train_idx, val_idx)
    sample_weight : array-like, optional
        Sliced per fold into training and validation weights.
    max_iter, tol : optional
        Overrides for ``self.max_iter`` / ``self.tol``.
    strict : bool
        When False, enables the relaxed routes: the fold-batched path, the
        SCAD/MCP path for non-squared-error losses, and ``cv_mode`` in the
        GLM sparse path.
    """
    from statgpu.linear_model.penalized._base import PenalizedGeneralizedLinearModel

    alpha_grid = np.asarray(alpha_grid, dtype=np.float64).ravel()
    n_alphas = len(alpha_grid)
    penalty_name = str(self.penalty).lower()
    loss_name = str(self.loss).lower()
    device_name = _device_to_name(cv_device)
    max_iter = int(self.max_iter if max_iter is None else max_iter)
    tol = self.tol if tol is None else tol

    # ── Fast path: Ridge eigendecomposition (CPU only, unweighted) ──
    _is_explicit_gpu = device_name in ("cuda", "torch")
    if loss_name == "squared_error" and penalty_name == "l2" and sample_weight is None and not _is_explicit_gpu:
        all_scores = np.full((len(folds), n_alphas), np.nan)
        for fold_idx, (train_idx, val_idx) in enumerate(folds):
            X_train = _slice_rows(X, train_idx)
            y_train = _slice_rows(y, train_idx)
            X_val = _slice_rows(X, val_idx)
            y_val = _slice_rows(y, val_idx)
            try:
                # First element of the returned triple is the per-alpha
                # validation score, in alpha_grid order.
                mse, _, _ = self._solve_ridge_fold_batch(
                    X_train, y_train, X_val, y_val, alpha_grid,
                )
                all_scores[fold_idx, :] = mse
            except Exception as e:
                # A failed fold keeps its NaN row; no per-fold fallback here.
                warnings.warn(
                    f"Ridge eig batch failed for fold {fold_idx}: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )
        return all_scores

    # Path solvers run from the largest alpha down; sort_idx maps each
    # position in alpha_sorted back to its column in alpha_grid.
    sort_idx = np.argsort(-alpha_grid)
    alpha_sorted = alpha_grid[sort_idx]
    all_scores = np.full((len(folds), n_alphas), np.nan)
    cv_solver = self._solver_for_cv(cv_device, X=X)

    # ── Fast path: fold-batched CV (all folds at once, GPU only) ──
    use_fold_batch = (
        not strict
        and loss_name in _FOLD_BATCH_CONFIGS
        and penalty_name in ("l1", "elasticnet", "en")
        and device_name in ("torch", "cuda")
    )
    if use_fold_batch:
        try:
            path = _glm_sparse_cv_folds(
                X, y, folds, alpha_sorted, penalty_name, self.l1_ratio,
                max_iter, tol, loss_name, device_name,
                sample_weight=sample_weight,
                loss_kwargs=getattr(self, '_loss_kwargs', None),
            )
            if path is not None and path["scores"] is not None:
                all_scores[:, sort_idx] = path["scores"]
                return all_scores
        except Exception as e:
            warnings.warn(
                f"Fold-batched {loss_name} sparse CV failed on {device_name}; "
                f"falling back to per-fold path: {e}",
                RuntimeWarning,
                stacklevel=2,
            )

    # ── Per-fold dispatch table ──
    # Each entry: (condition_fn, path_fn)
    # condition_fn(loss_name, penalty_name, cv_solver, strict) -> bool
    # path_fn(X_train, y_train, alpha_sorted, ..., X_val, y_val, sw_train, sw_val) -> dict or None

    def _cond_scad_mcp(loss_name, penalty_name, cv_solver, strict):
        # SCAD/MCP: always for squared error; other losses only when not strict.
        return penalty_name in ("scad", "mcp") and (loss_name == "squared_error" or not strict)

    def _path_scad_mcp(X_train, y_train, alpha_sorted, penalty_name, l1_ratio,
                       max_iter, tol, cv_device, X_val, y_val, sw_train, sw_val):
        return _scad_mcp_cv_path(
            loss_name, X_train, y_train, alpha_sorted, penalty_name,
            l1_ratio, max_iter, tol, cv_device,
            X_val=X_val, y_val=y_val, sample_weight=sw_train,
            val_sample_weight=sw_val,
        )

    def _cond_logistic(loss_name, penalty_name, cv_solver, strict):
        return loss_name == "logistic" and penalty_name in ("l1", "elasticnet", "en")

    def _path_logistic(X_train, y_train, alpha_sorted, penalty_name, l1_ratio,
                       max_iter, tol, cv_device, X_val, y_val, sw_train, sw_val):
        return _logistic_sparse_cv_path(
            X_train, y_train, alpha_sorted, penalty_name, l1_ratio,
            max_iter, tol, cv_device,
            X_val=X_val, y_val=y_val, sample_weight=sw_train,
            val_sample_weight=sw_val, return_path=False,
        )

    def _cond_squared(loss_name, penalty_name, cv_solver, strict):
        return loss_name == "squared_error" and penalty_name in ("l1", "elasticnet", "en")

    def _path_squared(X_train, y_train, alpha_sorted, penalty_name, l1_ratio,
                      max_iter, tol, cv_device, X_val, y_val, sw_train, sw_val):
        return _squared_error_sparse_cv_path(
            X_train, y_train, alpha_sorted, penalty_name, l1_ratio,
            max_iter, tol, cv_device,
            X_val=X_val, y_val=y_val, sample_weight=sw_train,
            val_sample_weight=sw_val, return_path=False,
        )

    def _cond_glm_sparse(loss_name, penalty_name, cv_solver, strict):
        # Delegates to the shared predicate also used by refit and
        # best-index selection.
        return self._uses_glm_sparse_path(penalty_name, cv_solver)

    def _path_glm_sparse(X_train, y_train, alpha_sorted, penalty_name, l1_ratio,
                         max_iter, tol, cv_device, X_val, y_val, sw_train, sw_val):
        return _glm_sparse_cv_path(
            loss_name, X_train, y_train, alpha_sorted, penalty_name,
            l1_ratio, max_iter, tol, cv_device,
            X_val=X_val, y_val=y_val, sample_weight=sw_train,
            val_sample_weight=sw_val, return_path=False,
            solver_name=cv_solver, cv_mode=not strict,
        )

    # Order matters: the first active path that returns scores wins.
    _per_fold_paths = [
        (_cond_scad_mcp, _path_scad_mcp),
        (_cond_logistic, _path_logistic),
        (_cond_squared, _path_squared),
        (_cond_glm_sparse, _path_glm_sparse),
    ]

    # Pre-check which paths are active for this loss/penalty combo
    active_paths = [(cond, path_fn) for cond, path_fn in _per_fold_paths
                    if cond(loss_name, penalty_name, cv_solver, strict)]

    # ── Per-fold loop ──
    for fold_idx, (train_idx, val_idx) in enumerate(folds):
        X_train = _slice_rows(X, train_idx)
        y_train = _slice_rows(y, train_idx)
        X_val = _slice_rows(X, val_idx)
        y_val = _slice_rows(y, val_idx)
        sw_train = _slice_rows(sample_weight, train_idx) if sample_weight is not None else None
        sw_val = _slice_rows(sample_weight, val_idx) if sample_weight is not None else None

        # Try each specialized path in order; a failure or a None/None-scores
        # result moves on to the next path.
        fold_handled = False
        for cond_fn, path_fn in active_paths:
            try:
                path = path_fn(
                    X_train, y_train, alpha_sorted, penalty_name,
                    self.l1_ratio, max_iter, tol, cv_device,
                    X_val=X_val, y_val=y_val,
                    sw_train=sw_train, sw_val=sw_val,
                )
                if path is not None and path["scores"] is not None:
                    all_scores[fold_idx, sort_idx] = path["scores"]
                    fold_handled = True
                    break
            except Exception as e:
                warnings.warn(
                    f"{path_fn.__name__} failed for {loss_name}+{penalty_name} "
                    f"fold {fold_idx}: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue

        if fold_handled:
            continue

        # ── General fallback: model.fit() per alpha ──
        # Fills row fold_idx of all_scores in place.
        self._cv_fold_general(
            all_scores, fold_idx, sort_idx, alpha_sorted,
            loss_name, cv_device, cv_solver, strict,
            X_train, y_train, X_val, y_val,
            sw_train, sw_val, max_iter, tol,
        )

    return all_scores
|
|
2353
|
+
|
|
2354
|
+
def _cv_fold_general(
    self, all_scores, fold_idx, sort_idx, alpha_sorted,
    loss_name, cv_device, cv_solver, strict,
    X_train, y_train, X_val, y_val,
    sw_train, sw_val, max_iter, tol,
):
    """General per-fold CV path: ``model.fit()`` per alpha with warm start.

    Fills row ``fold_idx`` of ``all_scores`` in place. The k-th alpha of
    ``alpha_sorted`` is written to column ``sort_idx[k]``. Alphas whose fit
    or evaluation fails are left NaN and logged.

    For non-strict SCAD/MCP with a non-squared-error loss, a single LLA path
    fit over the whole grid is tried first. If it produces no path, the
    per-alpha loop runs, without warm start for that combination.

    Validation losses are computed in one batch when a vectorised
    per-sample evaluator exists (squared error, or an entry in
    ``_LOSS_EVAL_DISPATCH``). Otherwise each fitted alpha is scored with
    ``_evaluate_loss_numpy``; previously those scores were silently left NaN.

    Returns
    -------
    None
    """
    from statgpu.linear_model.penalized._base import PenalizedGeneralizedLinearModel
    from statgpu.linear_model.penalized._fit_mixin import _resolve_loss_name

    penalty_name = str(self.penalty).lower()
    device_name = _device_to_name(cv_device)

    X_val_np = _to_numpy(X_val).astype(np.float64)
    y_val_np = _to_numpy(y_val).astype(np.float64).ravel()
    # Validation weights are converted to host float64 once and shared by
    # every evaluator below (LLA, batched and per-alpha).
    sw_val_np = (
        np.asarray(_to_numpy(sw_val), dtype=np.float64).ravel()
        if sw_val is not None
        else None
    )
    loss_fn = _resolve_loss_name(loss_name)

    # Disable warm-start for SCAD/MCP on non-squared-error losses
    _is_scad_mcp_non_se = penalty_name in ("scad", "mcp") and loss_name != "squared_error"
    use_warm_start = not _is_scad_mcp_non_se
    use_lla_path_cv = (
        not strict and loss_name != "squared_error" and penalty_name in ("scad", "mcp")
    )

    # Transfer to GPU if needed
    if device_name in ("cuda", "torch"):
        fold_backend = _backend_name_for_cv_device(cv_device)
        X_train_fit = _to_backend_float64(X_train, fold_backend)
        y_train_fit = _to_backend_float64(y_train, fold_backend)
        sw_train_fit = _to_backend_float64(sw_train, fold_backend) if sw_train is not None else None
    else:
        X_train_fit = X_train
        y_train_fit = y_train
        sw_train_fit = sw_train

    # Precompute XtX/Xty for squared-error GPU cache
    cv_cache, L_np = self._build_cv_cache(
        loss_name, device_name, X_train, y_train, sw_train
    )

    model = PenalizedGeneralizedLinearModel(
        loss=loss_name, penalty=self.penalty, alpha=alpha_sorted[0],
        l1_ratio=self.l1_ratio, device=cv_device, compute_inference=False,
        max_iter=max_iter, tol=tol, solver=cv_solver,
    )
    if cv_cache is not None:
        model._cv_cache = cv_cache
        model._preserve_cv_cache = True
        if L_np is not None and L_np > 0:
            model.lipschitz_L = L_np

    # LLA path for SCAD/MCP: one fit over the whole grid (smallest alpha
    # as the target, the full grid passed via _cv_alpha_path).
    if use_lla_path_cv:
        try:
            model.alpha = float(alpha_sorted[-1])
            if hasattr(model, "_penalty") and model._penalty is not None:
                model._penalty.alpha = float(alpha_sorted[-1])
            model._cv_alpha_path = np.asarray(alpha_sorted, dtype=np.float64)
            model.fit(X_train_fit, y_train_fit, sample_weight=sw_train_fit)
            path = getattr(model, "_cv_path_results", None)
            if path is not None:
                path_alphas = np.asarray(path["alpha"], dtype=np.float64)
                path_coefs = np.asarray(path["coef"], dtype=np.float64)
                path_intercepts = np.asarray(path["intercept"], dtype=np.float64)
                for alpha_idx_sorted, alpha in enumerate(alpha_sorted):
                    matches = np.flatnonzero(np.isclose(path_alphas, float(alpha), rtol=1e-10, atol=1e-14))
                    if matches.size == 0:
                        continue
                    path_idx = int(matches[-1])
                    val_loss = _evaluate_loss_numpy(
                        loss_name, loss_fn, X_val_np, y_val_np,
                        path_coefs[path_idx], float(path_intercepts[path_idx]),
                        True, sample_weight=sw_val_np,
                    )
                    all_scores[fold_idx, sort_idx[alpha_idx_sorted]] = val_loss
                for attr in ("_cv_alpha_path", "_cv_path_results", "_cv_cache", "_preserve_cv_cache"):
                    if hasattr(model, attr):
                        delattr(model, attr)
                return
            else:
                # path is None — cleanup LLA state only; _cv_cache and
                # _preserve_cv_cache are still needed for the warm-start fallback.
                for attr in ("_cv_alpha_path", "_cv_path_results"):
                    if hasattr(model, attr):
                        delattr(model, attr)
        except Exception:
            # Same as path-is-None: keep _cv_cache for warm-start fallback.
            for attr in ("_cv_alpha_path", "_cv_path_results"):
                if hasattr(model, attr):
                    delattr(model, attr)

    # Warm-started alpha loop: fit per alpha, collect coefs for batch eval
    prev_coef = None
    prev_intercept = None
    fitted_coefs = []  # (alpha_idx_sorted, coef_np, intercept)
    for alpha_idx_sorted, alpha in enumerate(alpha_sorted):
        try:
            if cv_cache is not None:
                model._cv_cache = cv_cache
            model.alpha = alpha
            if hasattr(model, "_penalty") and model._penalty is not None:
                model._penalty.alpha = alpha
            if use_warm_start and prev_coef is not None:
                model._init_coef = np.asarray(prev_coef, dtype=np.float64)
                model._init_intercept = prev_intercept
            else:
                model._init_coef = None
                model._init_intercept = None
            model.fit(X_train_fit, y_train_fit, sample_weight=sw_train_fit)

            coef_np = _to_numpy(model.coef_).ravel()
            intercept = float(model.intercept_)
            fitted_coefs.append((alpha_idx_sorted, coef_np.copy(), intercept))
            prev_coef = coef_np.copy()
            prev_intercept = intercept
        except Exception as exc:
            orig_idx = sort_idx[alpha_idx_sorted]
            all_scores[fold_idx, orig_idx] = np.nan
            logger.warning(
                "CV fold %d, alpha_idx %d (alpha=%.6g) fit failed: %s",
                fold_idx, orig_idx, alpha_sorted[alpha_idx_sorted], exc,
            )

    # Batch validation: one GEMM for all fitted alphas
    # Pre-build loss-specific params
    _loss_params = {}
    if loss_name == "negative_binomial":
        _loss_params["alpha"] = float(getattr(loss_fn, "alpha", _NB_ALPHA_DEFAULT))
    elif loss_name == "tweedie":
        _loss_params["power"] = float(getattr(loss_fn, "power", _TWEEDIE_POWER_DEFAULT))

    if fitted_coefs:
        idxs = np.array([fc[0] for fc in fitted_coefs])
        coef_mat = np.column_stack([fc[1] for fc in fitted_coefs])  # (n_features, n_fitted)
        intercepts = np.array([fc[2] for fc in fitted_coefs])  # (n_fitted,)
        eta_mat = X_val_np @ coef_mat + intercepts[np.newaxis, :]  # (n_val, n_fitted)

        # Evaluate loss per alpha
        sw = sw_val_np
        per_sample_loss = None

        if loss_name == "squared_error":
            # Direct batch computation: squared residual
            per_sample_loss = (y_val_np[:, np.newaxis] - eta_mat) ** 2
        else:
            # GLM losses: use registry
            entry = _LOSS_EVAL_DISPATCH.get(loss_name)
            if entry is not None:
                per_sample_fn, _ = entry
                per_sample_loss = per_sample_fn(eta_mat, y_val_np[:, np.newaxis], **_loss_params)

        if per_sample_loss is not None:
            if sw is not None:
                w_sum = float(np.sum(sw))
                if w_sum > 0:
                    scores_fitted = np.sum(sw[:, np.newaxis] * per_sample_loss, axis=0) / w_sum
                else:
                    scores_fitted = np.mean(per_sample_loss, axis=0)
            else:
                scores_fitted = np.mean(per_sample_loss, axis=0)
            for i, alpha_idx_sorted in enumerate(idxs):
                orig_idx = sort_idx[alpha_idx_sorted]
                all_scores[fold_idx, orig_idx] = float(scores_fitted[i])
        else:
            # No vectorised evaluator for this loss: score each fitted alpha
            # individually instead of silently leaving NaN for models that
            # did fit. The intercept is always applied, as in the batch path.
            for alpha_idx_sorted, coef_np, intercept in fitted_coefs:
                orig_idx = sort_idx[alpha_idx_sorted]
                try:
                    all_scores[fold_idx, orig_idx] = float(
                        _evaluate_loss_numpy(
                            loss_name, loss_fn, X_val_np, y_val_np,
                            coef_np, intercept, True, sample_weight=sw_val_np,
                        )
                    )
                except Exception as exc:
                    logger.warning(
                        "CV fold %d, alpha_idx %d (alpha=%.6g) validation loss failed: %s",
                        fold_idx, orig_idx, alpha_sorted[alpha_idx_sorted], exc,
                    )

    # Drop the fold-specific cache so it is not carried beyond this fold.
    if hasattr(model, "_cv_cache"):
        del model._cv_cache
    if hasattr(model, "_preserve_cv_cache"):
        del model._preserve_cv_cache
|
|
2517
|
+
|
|
2518
|
+
def _build_cv_cache(self, loss_name, device_name, X_train, y_train, sw_train):
|
|
2519
|
+
"""Precompute XtX/Xty for squared-error GPU cache. Returns (cache_dict, L_np)."""
|
|
2520
|
+
if loss_name != "squared_error" or device_name not in ("cuda", "torch"):
|
|
2521
|
+
return None, None
|
|
2522
|
+
X_train_np = _to_numpy(X_train).astype(np.float64)
|
|
2523
|
+
y_train_np = _to_numpy(y_train).astype(np.float64).ravel()
|
|
2524
|
+
n_tr, _ = X_train_np.shape
|
|
2525
|
+
sw_np = _to_numpy(sw_train).astype(np.float64).ravel() if sw_train is not None else None
|
|
2526
|
+
if sw_np is not None:
|
|
2527
|
+
w_sum = float(sw_np.sum())
|
|
2528
|
+
X_mean_np = np.average(X_train_np, axis=0, weights=sw_np)
|
|
2529
|
+
y_mean_np = float(np.average(y_train_np, weights=sw_np))
|
|
2530
|
+
Xc_np = X_train_np - X_mean_np
|
|
2531
|
+
yc_np = y_train_np - y_mean_np
|
|
2532
|
+
sqrt_w = np.sqrt(sw_np)
|
|
2533
|
+
W_Xc = Xc_np * sqrt_w[:, None]
|
|
2534
|
+
XtX_np = W_Xc.T @ W_Xc
|
|
2535
|
+
Xty_np = (Xc_np * sw_np[:, None]).T @ yc_np
|
|
2536
|
+
L_np = float(_max_eigval_power(XtX_np)) / max(w_sum, 1.0)
|
|
2537
|
+
n_effective = w_sum
|
|
2538
|
+
else:
|
|
2539
|
+
X_mean_np = np.mean(X_train_np, axis=0)
|
|
2540
|
+
y_mean_np = np.mean(y_train_np)
|
|
2541
|
+
Xc_np = X_train_np - X_mean_np
|
|
2542
|
+
yc_np = y_train_np - y_mean_np
|
|
2543
|
+
XtX_np = Xc_np.T @ Xc_np
|
|
2544
|
+
Xty_np = Xc_np.T @ yc_np
|
|
2545
|
+
L_np = float(_max_eigval_power(XtX_np)) / n_tr
|
|
2546
|
+
n_effective = float(n_tr)
|
|
2547
|
+
if device_name == "cuda":
|
|
2548
|
+
import cupy as cp
|
|
2549
|
+
cache = {"XtX": cp.asarray(XtX_np), "Xty": cp.asarray(Xty_np), "n_effective": n_effective}
|
|
2550
|
+
else:
|
|
2551
|
+
import torch
|
|
2552
|
+
_torch_dev = "cuda" if torch.cuda.is_available() else "cpu"
|
|
2553
|
+
cache = {"XtX": torch.as_tensor(XtX_np, device=_torch_dev, dtype=torch.float64),
|
|
2554
|
+
"Xty": torch.as_tensor(Xty_np, device=_torch_dev, dtype=torch.float64),
|
|
2555
|
+
"n_effective": n_effective}
|
|
2556
|
+
return cache, L_np
|
|
2557
|
+
|
|
2558
|
+
def fit(self, X, y, sample_weight=None):
    """Select the penalty strength ``alpha`` by cross-validation, then refit.

    Two strategies are supported, chosen by ``self.cv_strategy``:

    * ``"two_stage"``: every alpha is first scored with relaxed solves
      (``strict=False``, a reduced iteration budget and a looser tolerance).
      A candidate subset chosen by ``_two_stage_candidate_mask`` is then
      re-scored with strict solves using the original ``max_iter``/``tol``.
      The best alpha is selected among the strictly re-scored candidates
      only. Emits ``ApproximateCVWarning`` unless ``acknowledge_approx`` is
      truthy.
    * any other value: every alpha is scored with strict solves.

    Parameters
    ----------
    X : array-like
        Design matrix. Objects without a ``shape`` attribute (lists, tuples,
        ...) are converted to a float64 NumPy array.
    y : array-like
        Response; converted the same way as ``X``.
    sample_weight : array-like or None, default=None
        Optional per-sample weights, forwarded to CV scoring and to the
        final refit. Objects without a ``shape`` attribute are converted to
        a float64 NumPy array.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If the alpha grid is empty, or if ``cv_splits`` yields no folds
        (for example a generator already consumed by a previous ``fit``).

    Notes
    -----
    Sets ``alpha_grid_``, ``cv_strategy_``, ``cv_selected_device_``,
    ``alpha_``, ``best_score_`` (the negated mean CV score at ``alpha_``,
    sklearn "higher is better" convention), ``cv_results_``, ``estimator_``,
    ``coef_`` and ``intercept_``. In ``cv_results_`` the ``*_stage1`` entries
    are ``None`` unless the two-stage strategy ran.
    """
    # Normalize array-like inputs (lists, tuples, etc.) to arrays.
    if not hasattr(X, 'shape'):
        X = np.asarray(X, dtype=np.float64)
    if not hasattr(y, 'shape'):
        y = np.asarray(y, dtype=np.float64)
    # Give sample_weight the same treatment so downstream CV/refit code
    # always receives an array rather than a plain list.
    if sample_weight is not None and not hasattr(sample_weight, 'shape'):
        sample_weight = np.asarray(sample_weight, dtype=np.float64)

    if self._alpha_grid_input is not None:
        alpha_grid = np.asarray(self._alpha_grid_input, dtype=np.float64)
    else:
        alpha_grid = self._generate_alpha_grid(X, y)
    alpha_grid = np.asarray(alpha_grid, dtype=np.float64).ravel()
    # Fail fast: with no alphas there is nothing to select.
    if alpha_grid.size == 0:
        raise ValueError("PenalizedGLM_CV requires a non-empty alpha grid.")

    self.alpha_grid_ = alpha_grid
    n_samples = X.shape[0]
    n_alphas = len(alpha_grid)
    penalty_name = str(self.penalty).lower()
    # Normalize case the same way as penalty_name so the SCAD/MCP check
    # below is not defeated by capitalization.
    loss_name = str(self.loss).lower()
    cv_device = self._effective_cv_device(X, penalty_name, n_alphas)
    cv_solver = self._solver_for_cv(cv_device, X=X)
    self.cv_strategy_ = self.cv_strategy
    self.cv_selected_device_ = _device_to_name(cv_device)

    if self.cv_splits is not None:
        # Materialize once: both CV stages iterate the folds, and a generator
        # would be exhausted after its first pass.
        if isinstance(self.cv_splits, list):
            folds = self.cv_splits
        else:
            folds = list(self.cv_splits)
        if not folds:
            # A generator passed as cv_splits is consumed by the first fit();
            # silently cross-validating over zero folds would yield NaN scores.
            raise ValueError(
                "cv_splits yielded no folds; if it is a generator it may have "
                "been consumed by a previous fit(). Pass a list instead."
            )
    else:
        folds = kfold_indices(n_samples, self.cv, self.random_state)

    all_scores_stage1 = None
    mean_scores_stage1 = None
    refined_mask = np.ones(n_alphas, dtype=bool)

    if self.cv_strategy == "two_stage":
        if not self.acknowledge_approx:
            warnings.warn(
                "PenalizedGLM_CV(cv_strategy='two_stage') uses relaxed CV "
                "solves to screen the alpha grid before strict refinement. "
                "The final refit still uses the original max_iter and tol. "
                "Pass acknowledge_approx=True to silence this warning.",
                ApproximateCVWarning,
                stacklevel=2,
            )
        # Stage 1 budget: a quarter of max_iter (at least 50, never more than
        # max_iter) and a 10x looser tolerance (at least 1e-4).
        stage1_max_iter = min(int(self.max_iter), max(50, int(self.max_iter) // 4))
        stage1_tol = max(float(self.tol) * 10.0, 1e-4)
        all_scores_stage1 = self._compute_cv_scores(
            X,
            y,
            alpha_grid,
            cv_device,
            folds,
            sample_weight=sample_weight,
            max_iter=stage1_max_iter,
            tol=stage1_tol,
            strict=False,
        )
        mean_scores_stage1 = np.nanmean(all_scores_stage1, axis=0)
        refined_mask = _two_stage_candidate_mask(
            mean_scores_stage1,
            refine_top_k=self.refine_top_k,
        )
        # Squared-error SCAD/MCP: every alpha is re-scored strictly.
        if loss_name == "squared_error" and penalty_name in ("scad", "mcp"):
            refined_mask[:] = True
        # Never enter stage 2 with an empty candidate set.
        if not np.any(refined_mask):
            refined_mask[:] = True

        # Stage 2: strict solves on the candidate alphas only.
        refined_alpha_grid = alpha_grid[refined_mask]
        refined_scores = self._compute_cv_scores(
            X,
            y,
            refined_alpha_grid,
            cv_device,
            folds,
            sample_weight=sample_weight,
            max_iter=self.max_iter,
            tol=self.tol,
            strict=True,
        )
        # Reported table: strict scores replace the relaxed ones for refined
        # alphas; non-refined alphas keep their stage-1 scores.
        all_scores = np.array(all_scores_stage1, copy=True)
        all_scores[:, refined_mask] = refined_scores
        mean_scores = np.nanmean(all_scores, axis=0)
        # Select only among strictly scored alphas, then map the position
        # within the refined sub-grid back to an index into the full grid.
        refined_mean = np.nanmean(refined_scores, axis=0)
        refined_best = self._best_index_from_scores(
            refined_mean,
            refined_alpha_grid,
            cv_solver,
        )
        best_idx = int(np.flatnonzero(refined_mask)[refined_best])
    else:
        all_scores = self._compute_cv_scores(
            X,
            y,
            alpha_grid,
            cv_device,
            folds,
            sample_weight=sample_weight,
            max_iter=self.max_iter,
            tol=self.tol,
            strict=True,
        )
        mean_scores = np.nanmean(all_scores, axis=0)
        best_idx = int(self._best_index_from_scores(mean_scores, alpha_grid, cv_solver))

    best_alpha = float(alpha_grid[best_idx])
    self.alpha_ = best_alpha
    # sklearn convention: best_score_ is negative loss (higher is better)
    self.best_score_ = -float(mean_scores[best_idx])
    self.cv_results_ = {
        "alpha": alpha_grid,
        "mean_score": mean_scores,
        "all_scores": all_scores,
        "cv_strategy_": self.cv_strategy_,
        "cv_selected_device_": self.cv_selected_device_,
        "mean_score_stage1": mean_scores_stage1,
        "all_scores_stage1": all_scores_stage1,
        "refined_mask": refined_mask,
    }

    # Final model: refit on all data at the selected alpha.
    self.estimator_ = self._refit_best(X, y, best_alpha, sample_weight=sample_weight)
    self.coef_ = self.estimator_.coef_
    self.intercept_ = self.estimator_.intercept_

    self._fitted = True
    return self
|
|
2681
|
+
|
|
2682
|
+
def predict(self, X):
    """Return predictions from ``estimator_``, the model refit at ``alpha_``.

    Raises
    ------
    RuntimeError
        If ``fit`` has not been called yet.
    """
    is_fitted = getattr(self, '_fitted', False)
    if is_fitted:
        return self.estimator_.predict(X)
    raise RuntimeError("PenalizedGLM_CV is not fitted yet. Call fit() first.")
|
|
2687
|
+
|
|
2688
|
+
def score(self, X, y, sample_weight=None):
    """Score ``X``/``y`` with the estimator refit at the CV-selected alpha.

    This delegates to ``self.estimator_.score``, forwarding
    ``sample_weight``; the metric is whatever that estimator reports. For
    squared_error loss this is expected to be R², and for GLM losses a
    deviance-based pseudo-R² (1 - deviance/null_deviance) — confirm against
    the estimator's own ``score``.

    Note: ``best_score_`` is the negated cross-validation loss (sklearn
    convention), which is a different quantity from what ``score()`` returns.

    Raises
    ------
    RuntimeError
        If ``fit`` has not been called yet.
    """
    if getattr(self, '_fitted', False):
        fitted_estimator = self.estimator_
        return fitted_estimator.score(X, y, sample_weight=sample_weight)
    raise RuntimeError("PenalizedGLM_CV is not fitted yet. Call fit() first.")
|