statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,410 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shared base class and utilities for cross-validated estimators.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = ["CVEstimatorBase", "folds_are_complete", "INTERCEPT_CLIP_BOUND"]
|
|
8
|
+
|
|
9
|
+
import hashlib
|
|
10
|
+
from collections import OrderedDict
|
|
11
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
from statgpu._base import BaseEstimator
|
|
16
|
+
|
|
17
|
+
# Shared constant: intercept clipping bound for CV proximal operators
|
|
18
|
+
INTERCEPT_CLIP_BOUND = 15.0
|
|
19
|
+
from statgpu._config import Device
|
|
20
|
+
from statgpu.backends import _to_numpy
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _torch_cuda_available():
|
|
24
|
+
"""Check if torch CUDA is available (shared utility)."""
|
|
25
|
+
try:
|
|
26
|
+
import torch
|
|
27
|
+
return torch.cuda.is_available()
|
|
28
|
+
except Exception:
|
|
29
|
+
return False
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# ---------------------------------------------------------------------------
|
|
33
|
+
# K-fold splitting
|
|
34
|
+
# ---------------------------------------------------------------------------
|
|
35
|
+
|
|
36
|
+
def kfold_indices(
    n_samples: int,
    n_splits: int = 5,
    random_state: Optional[int] = None,
    shuffle: bool = True,
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Split ``range(n_samples)`` into K train/validation index pairs.

    Fold sizes differ by at most one: the first ``n_samples % n_splits``
    folds receive one extra sample.

    Parameters
    ----------
    n_samples : int
        Total number of samples.
    n_splits : int
        Number of folds (at least 2, at most ``n_samples``).
    random_state : int or None
        Seed passed to ``np.random.default_rng`` when shuffling.
    shuffle : bool
        Whether to permute the indices before cutting them into folds.

    Returns
    -------
    folds : list of (train_idx, val_idx) tuples

    Raises
    ------
    ValueError
        If ``n_splits`` is smaller than 2 or larger than ``n_samples``.
    """
    if n_splits < 2:
        raise ValueError(f"n_splits={n_splits} must be at least 2")
    if n_splits > n_samples:
        raise ValueError(
            f"n_splits={n_splits} cannot be greater than n_samples={n_samples}"
        )

    order = np.arange(n_samples)
    if shuffle:
        np.random.default_rng(random_state).shuffle(order)

    # Per-fold sizes, then cumulative boundaries of each validation slice.
    base, extra = divmod(n_samples, n_splits)
    sizes = np.full(n_splits, base, dtype=int)
    sizes[:extra] += 1
    stops = np.cumsum(sizes)
    starts = stops - sizes

    return [
        (np.concatenate([order[:lo], order[hi:]]), order[lo:hi])
        for lo, hi in zip(starts, stops)
    ]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def folds_are_complete(folds, n_samples: int) -> bool:
    """Check that the validation parts of ``folds`` cover every sample exactly once.

    Parameters
    ----------
    folds : iterable of (train_idx, val_idx)
        Only the second element of each pair is inspected.
    n_samples : int
        Number of samples the folds are supposed to partition.

    Returns
    -------
    bool
        True when the concatenated validation indices are a permutation of
        ``range(n_samples)``. An empty ``folds`` is complete only when
        ``n_samples`` is 0.
    """
    val_parts = [fold[1] for fold in folds]
    if not val_parts:
        # np.concatenate raises on an empty list; answer the question directly.
        return bool(n_samples == 0)

    val_indices = np.concatenate(val_parts)
    if len(val_indices) != n_samples:
        return False
    return bool(np.array_equal(np.sort(val_indices), np.arange(n_samples)))
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def hash_cv_data(X, y, sample_weight=None) -> bytes:
    """Return a 16-byte blake2b digest identifying ``X``, ``y`` and weights.

    The shape ``(n, p)`` of ``X`` is always hashed first. When
    ``n * p <= 10,000,000`` the full float64 contents of ``X``, ``y`` and
    (if given) ``sample_weight`` are hashed. For larger inputs only a
    sample is hashed: up to 100 evenly spaced rows plus the last row, the
    sampled row indices themselves, and the mean and standard deviation of
    each array. This keeps hashing fast at the cost of a nonzero chance
    that two different datasets produce the same digest.
    """
    _FULL_HASH_THRESHOLD = 10_000_000  # n * p threshold for full hashing

    def _moments(arr):
        # Mean and std packed as two float64 values.
        return np.asarray([arr.mean(), arr.std()], dtype=np.float64).tobytes()

    digest = hashlib.blake2b(digest_size=16)
    features = np.asarray(_to_numpy(X), dtype=np.float64)
    target = np.asarray(_to_numpy(y), dtype=np.float64).ravel()
    n, p = features.shape
    digest.update(np.asarray([n, p], dtype=np.int64).tobytes())

    if n * p <= _FULL_HASH_THRESHOLD:
        # Small dataset: every byte contributes to the digest.
        digest.update(features.tobytes())
        digest.update(target.tobytes())
        if sample_weight is not None:
            weights = np.asarray(_to_numpy(sample_weight), dtype=np.float64).ravel()
            digest.update(weights.tobytes())
        return digest.digest()

    # Very large dataset: evenly spaced rows, always including both ends.
    stride = max(1, n // 100)
    rows = np.arange(0, n, stride)[:100]
    if rows[0] != 0:
        rows = np.concatenate([[0], rows])
    if rows[-1] != n - 1:
        rows = np.concatenate([rows, [n - 1]])

    # The row indices are hashed too, so reordered data is less likely to collide.
    digest.update(rows.astype(np.int64).tobytes())
    digest.update(features[rows].tobytes())
    digest.update(target[rows].tobytes())
    digest.update(_moments(features))
    digest.update(_moments(target))
    if sample_weight is not None:
        weights = np.asarray(_to_numpy(sample_weight), dtype=np.float64).ravel()
        digest.update(weights[rows].tobytes())
        digest.update(_moments(weights))
    return digest.digest()
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def validate_cv_sample_weight(sample_weight, n_samples: int):
    """Check that ``sample_weight`` is usable for cross-validation.

    The checks run on a flattened float64 NumPy copy, but the object handed
    back is the caller's original array, so its backend is left untouched.

    Returns
    -------
    The original ``sample_weight`` object, or None when None was passed.

    Raises
    ------
    ValueError
        If the flattened length differs from ``n_samples``, or any weight
        is negative, or any weight is not finite.
    """
    if sample_weight is None:
        return None

    host_weights = _to_numpy(sample_weight).ravel().astype(np.float64)
    n_weights = host_weights.shape[0]
    if n_weights != n_samples:
        raise ValueError(
            f"sample_weight length {n_weights} != n_samples {n_samples}"
        )
    if (host_weights < 0).any():
        raise ValueError("sample_weight must be non-negative")
    # NaN is not < 0, so it is caught here rather than by the sign check.
    if not np.isfinite(host_weights).all():
        raise ValueError("sample_weight must be finite")

    return sample_weight
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# ---------------------------------------------------------------------------
|
|
162
|
+
# LRU cache for CV results
|
|
163
|
+
# ---------------------------------------------------------------------------
|
|
164
|
+
|
|
165
|
+
class CVCache:
    """Simple LRU cache for cross-validation results.

    Thread-safe: ``get`` and ``put`` both run under a single lock.

    Parameters
    ----------
    maxsize : int
        Maximum number of cached entries. Once exceeded, the least recently
        used entries are evicted.
    """

    def __init__(self, maxsize: int = 64):
        import threading

        self._cache: OrderedDict = OrderedDict()
        self._maxsize = maxsize
        self._lock = threading.Lock()

    def get(self, key: str):
        """Return the value stored under ``key``, or None if it is absent.

        A hit marks the entry as most recently used.
        """
        with self._lock:
            if key in self._cache:
                self._cache.move_to_end(key)
                return self._cache[key]
            return None

    def put(self, key: str, value):
        """Store ``value`` under ``key`` and evict old entries beyond ``maxsize``."""
        with self._lock:
            self._cache[key] = value
            self._cache.move_to_end(key)
            while len(self._cache) > self._maxsize:
                self._cache.popitem(last=False)

    @staticmethod
    def make_key(*args) -> str:
        """Generate a blake2b hex key from arbitrary arguments.

        Objects exposing both ``tobytes`` and ``shape`` are hashed by shape,
        dtype and raw content. Everything else is hashed through ``str()``.
        Each component is written with a type tag and a length prefix, so
        argument boundaries are unambiguous: ``("ab", "c")`` and
        ``("a", "bc")`` produce different keys.
        """
        h = hashlib.blake2b(digest_size=32)

        def _feed(tag: bytes, payload: bytes) -> None:
            # tag + 8-byte length + payload keeps adjacent fields separable.
            h.update(tag)
            h.update(len(payload).to_bytes(8, "little"))
            h.update(payload)

        for arg in args:
            if hasattr(arg, 'tobytes') and hasattr(arg, 'shape'):
                # Array-like: hash shape + dtype + content bytes.
                arr = np.ascontiguousarray(_to_numpy(arg))
                _feed(b"S", str(arg.shape).encode())
                _feed(b"T", str(arr.dtype).encode())
                _feed(b"D", arr.tobytes())
            else:
                # NOTE(review): objects without ``tobytes`` land here and are
                # keyed by str(), which may abbreviate large containers —
                # confirm whether such inputs can reach this method.
                _feed(b"O", str(arg).encode())
        return h.hexdigest()
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# ---------------------------------------------------------------------------
|
|
216
|
+
# GPU input detection
|
|
217
|
+
# ---------------------------------------------------------------------------
|
|
218
|
+
|
|
219
|
+
def detect_gpu_input(X, y) -> Tuple[str, Any, Any]:
    """Classify ``X`` and ``y`` as NumPy, CuPy or Torch inputs.

    CuPy and Torch are probed with guarded imports, so a missing package
    simply means that backend is never reported.

    Returns
    -------
    backend : str
        One of 'numpy', 'cupy', 'torch'. A GPU backend is reported only
        when both inputs belong to it.
    X, y : arrays
        The inputs unchanged, except when one is a CuPy array and the other
        a Torch tensor: a ``RuntimeWarning`` is emitted and both are
        converted with ``_to_numpy``. When only one of the two is a
        CuPy/Torch object, the backend is 'numpy' and both inputs are
        returned as given.
    """
    import warnings as _warnings

    kind_of_X = None
    kind_of_y = None

    try:
        import cupy as cp
    except ImportError:
        cp = None
    if cp is not None:
        if isinstance(X, cp.ndarray):
            kind_of_X = 'cupy'
        if isinstance(y, cp.ndarray):
            kind_of_y = 'cupy'

    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None:
        if isinstance(X, torch.Tensor):
            kind_of_X = 'torch'
        if isinstance(y, torch.Tensor):
            kind_of_y = 'torch'

    both_tagged = kind_of_X is not None and kind_of_y is not None
    if not both_tagged:
        return 'numpy', X, y
    if kind_of_X == kind_of_y:
        return kind_of_X, X, y

    # One CuPy input and one Torch input: fall back to host arrays.
    _warnings.warn(
        f"Mixed backend detected: X is {kind_of_X} but y is {kind_of_y}. "
        "Both arrays should use the same backend. Falling back to numpy.",
        RuntimeWarning,
        stacklevel=2,
    )
    return 'numpy', _to_numpy(X), _to_numpy(y)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
# ---------------------------------------------------------------------------
|
|
273
|
+
# Batch MSE computation
|
|
274
|
+
# ---------------------------------------------------------------------------
|
|
275
|
+
|
|
276
|
+
def batch_mse(
    X_val,
    y_val,
    coefs: np.ndarray,
    intercepts: Optional[np.ndarray] = None,
    sample_weight=None,
    chunk_size: int = 256,
) -> np.ndarray:
    """Compute MSE for multiple coefficient vectors on a validation set.

    Processes models in chunks to limit peak memory to
    O(chunk_size * n_val) instead of O(n_models * n_val). All inputs are
    converted with ``_to_numpy`` before any arithmetic.

    Parameters
    ----------
    X_val : array, shape (n_val, n_features)
    y_val : array, shape (n_val,)
    coefs : array, shape (n_models, n_features)
    intercepts : array, shape (n_models,) or None
    sample_weight : array, shape (n_val,) or None
        When given, the result is the weighted mean of squared residuals.
        If the weights sum to a non-positive value the result is NaN for
        every model.
    chunk_size : int
        Number of models to process at once (default 256). Must be >= 1.

    Returns
    -------
    mse : array, shape (n_models,)

    Raises
    ------
    ValueError
        If ``coefs`` or ``X_val`` is not 2D, if their feature dimensions
        differ, if ``y_val`` and ``X_val`` disagree on the sample count,
        or if ``chunk_size`` is smaller than 1.
    """
    X_val = _to_numpy(X_val)
    y_val = _to_numpy(y_val).ravel()
    coefs = _to_numpy(coefs)

    # Validate dimensions
    if coefs.ndim != 2:
        raise ValueError(f"coefs must be 2D (n_models, n_features), got shape {coefs.shape}")
    if X_val.ndim != 2:
        raise ValueError(f"X_val must be 2D (n_samples, n_features), got shape {X_val.shape}")
    if coefs.shape[1] != X_val.shape[1]:
        raise ValueError(
            f"Feature dimension mismatch: coefs has {coefs.shape[1]} features, "
            f"X_val has {X_val.shape[1]} features"
        )
    if y_val.shape[0] != X_val.shape[0]:
        raise ValueError(
            f"Sample count mismatch: y has {y_val.shape[0]} samples, "
            f"X_val has {X_val.shape[0]} samples"
        )
    # A negative step would make the chunk loop empty and hand back the
    # uninitialised buffer allocated below; zero would fail inside range().
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    n_models = coefs.shape[0]

    if intercepts is not None:
        intercepts = _to_numpy(intercepts)

    if sample_weight is not None:
        sw = _to_numpy(sample_weight).ravel()
        sw_sum = float(np.sum(sw))
    else:
        sw = None
        sw_sum = 0.0

    mse = np.empty(n_models, dtype=np.float64)

    # Process in chunks to limit peak memory
    for start in range(0, n_models, chunk_size):
        end = min(start + chunk_size, n_models)
        coefs_chunk = coefs[start:end]

        # y_pred shape: (chunk_size, n_val)
        y_pred = coefs_chunk @ X_val.T
        if intercepts is not None:
            y_pred = y_pred + intercepts[start:end, None]

        residuals = y_val[None, :] - y_pred  # (chunk_size, n_val)

        if sw is not None:
            if sw_sum > 0:
                mse[start:end] = np.sum(residuals ** 2 * sw[None, :], axis=1) / sw_sum
            else:
                # No usable weight mass: the weighted mean is undefined.
                mse[start:end] = np.nan
        else:
            mse[start:end] = np.mean(residuals ** 2, axis=1)

    return mse
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
# ---------------------------------------------------------------------------
|
|
360
|
+
# Base class
|
|
361
|
+
# ---------------------------------------------------------------------------
|
|
362
|
+
|
|
363
|
+
class CVEstimatorBase(BaseEstimator):
    """
    Shared plumbing for model-specific cross-validated estimators.

    The class stays deliberately thin: subclasses own their CV search and
    fitted attributes, while ``predict``, ``score`` and ``summary`` simply
    forward to the fitted ``estimator_``.
    """

    def __init__(
        self,
        *,
        cv: int = 5,
        random_state: Optional[int] = None,
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
    ):
        super().__init__(device=device, n_jobs=n_jobs)
        self.cv = int(cv)
        if self.cv < 2:
            raise ValueError(f"cv must be >= 2, got {self.cv}")
        self.random_state = random_state

        # Fitted attributes shared by all CV estimators; set by subclasses.
        self.best_score_ = None
        self.cv_results_ = None
        self.estimator_ = None

    def _fitted_estimator(self):
        """Return ``estimator_`` after the fitted and not-None checks pass."""
        self._check_is_fitted()
        fitted = self.estimator_
        if fitted is None:
            raise RuntimeError("No fitted base estimator is available.")
        return fitted

    def predict(self, X):
        """Forward ``X`` to the fitted estimator's ``predict``."""
        return self._fitted_estimator().predict(X)

    def score(self, X, y):
        """Forward ``X`` and ``y`` to the fitted estimator's ``score``."""
        return self._fitted_estimator().score(X, y)

    def summary(self):
        """Forward to the fitted estimator's ``summary``.

        Raises ``RuntimeError`` when the fitted estimator has no
        ``summary`` attribute.
        """
        fitted = self._fitted_estimator()
        if hasattr(fitted, "summary"):
            return fitted.summary()
        raise RuntimeError(
            f"{fitted.__class__.__name__} does not implement summary()."
        )
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Generic cross-validation engine for penalized GLM models.
|
|
3
|
+
|
|
4
|
+
Provides a reusable CV loop that can be parameterized by:
|
|
5
|
+
- Any loss function (squared_error, logistic, poisson, etc.)
|
|
6
|
+
- Any penalty type (l1, l2, elasticnet, scad, mcp, etc.)
|
|
7
|
+
- Any backend (numpy, cupy, torch)
|
|
8
|
+
|
|
9
|
+
.. note::
|
|
10
|
+
|
|
11
|
+
**Reference Implementation**: ``run_cv`` is a simple, readable reference
|
|
12
|
+
implementation intended for:
|
|
13
|
+
- Custom estimators that need a basic CV loop
|
|
14
|
+
- Testing and prototyping new CV strategies
|
|
15
|
+
- Documentation of the CV algorithm
|
|
16
|
+
|
|
17
|
+
The production CV paths (PenalizedGLM_CV, LassoCV, RidgeCV, etc.) use
|
|
18
|
+
their own optimized loops with warm-starting, fold batching, and
|
|
19
|
+
backend-specific optimizations. For production use, prefer those
|
|
20
|
+
estimators directly.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
__all__ = ["run_cv"]
|
|
26
|
+
|
|
27
|
+
import logging
|
|
28
|
+
from typing import Any, Callable, List, Optional, Tuple
|
|
29
|
+
|
|
30
|
+
import numpy as np
|
|
31
|
+
|
|
32
|
+
from statgpu.cross_validation._base import (
|
|
33
|
+
CVCache,
|
|
34
|
+
kfold_indices,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def run_cv(
    X,
    y,
    alpha_grid: np.ndarray,
    evaluate_fold_fn: Callable,
    n_folds: int = 5,
    random_state: Optional[int] = None,
    minimize: bool = True,
    cache: Optional[CVCache] = None,
    cache_key_fn: Optional[Callable] = None,
    sample_weight=None,
    raise_on_error: bool = False,
) -> Tuple[float, np.ndarray, np.ndarray]:
    """Execute K-fold cross-validation.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Feature matrix.
    y : array, shape (n_samples,)
        Target vector.
    alpha_grid : array, shape (n_alphas,)
        Regularization parameter grid.
    evaluate_fold_fn : callable
        Function ``(X_train, y_train, X_val, y_val, alpha,
        sample_weight_train=None, sample_weight_val=None) -> score``
        that trains on the training fold and returns a scalar score on
        the validation fold.
    n_folds : int
        Number of CV folds.
    random_state : int or None
        Random seed for fold generation.
    minimize : bool
        If True, lower score is better. If False, higher score is better.
    cache : CVCache or None
        Optional LRU cache for CV results. Used only together with
        ``cache_key_fn``.
    cache_key_fn : callable or None
        Function ``(X, y, alpha_grid, folds) -> str`` for cache key.
    sample_weight : array or None
        Optional sample weights (passed through to evaluate_fold_fn).
    raise_on_error : bool, default False
        If True, re-raise exceptions from evaluate_fold_fn instead of
        logging a warning and setting the score to NaN.

    Returns
    -------
    best_alpha : float
        Alpha value that optimizes the CV score.
    mean_scores : array, shape (n_alphas,)
        Mean CV score for each alpha (NaN entries are ignored).
    all_scores : array, shape (n_folds, n_alphas,)
        Per-fold CV scores.

    Raises
    ------
    ValueError
        If ``X``, ``y`` and ``sample_weight`` disagree on the sample count,
        if ``alpha_grid`` is empty, or if no alpha has a finite mean score.

    Notes
    -----
    Arrays are copied both when stored in the cache and when returned from
    a cache hit, so callers may mutate the result without corrupting the
    cached entry.
    """
    # 0. Validate inputs
    n_samples = X.shape[0]
    if y.shape[0] != n_samples:
        raise ValueError(f"X and y have different number of samples: {n_samples} vs {y.shape[0]}")
    if len(alpha_grid) == 0:
        raise ValueError("alpha_grid must not be empty")
    if sample_weight is not None and len(sample_weight) != n_samples:
        raise ValueError(
            f"sample_weight length {len(sample_weight)} != n_samples {n_samples}"
        )

    # 1. Generate folds
    folds = kfold_indices(n_samples, n_folds, random_state)

    # 2. Check cache
    cache_key = None
    if cache is not None and cache_key_fn is not None:
        cache_key = cache_key_fn(X, y, alpha_grid, folds)
        cached = cache.get(cache_key)
        if cached is not None:
            # Return copies: the cached arrays must stay untouched by callers.
            cached_alpha, cached_mean, cached_all = cached
            return (
                cached_alpha,
                np.array(cached_mean, copy=True),
                np.array(cached_all, copy=True),
            )

    # 3. Evaluate each (fold, alpha) pair
    n_alphas = len(alpha_grid)
    all_scores = np.full((n_folds, n_alphas), np.nan)

    for fold_idx, (train_idx, val_idx) in enumerate(folds):
        X_train = X[train_idx]
        y_train = y[train_idx]
        X_val = X[val_idx]
        y_val = y[val_idx]

        sw_train = sample_weight[train_idx] if sample_weight is not None else None
        sw_val = sample_weight[val_idx] if sample_weight is not None else None

        for alpha_idx, alpha in enumerate(alpha_grid):
            try:
                score = evaluate_fold_fn(
                    X_train, y_train, X_val, y_val, alpha,
                    sample_weight_train=sw_train,
                    sample_weight_val=sw_val,
                )
                all_scores[fold_idx, alpha_idx] = score
            except (ValueError, FloatingPointError, np.linalg.LinAlgError, RuntimeError) as exc:
                if raise_on_error:
                    raise
                # Best-effort mode: record the failure as NaN and keep going.
                all_scores[fold_idx, alpha_idx] = np.nan
                logger.warning(
                    "CV fold %d, alpha_idx %d failed: %s",
                    fold_idx, alpha_idx, exc,
                )

    # 4. Aggregate across folds
    mean_scores = np.nanmean(all_scores, axis=0)

    # Guard against all-NaN slices (all folds failed for every alpha)
    finite_mask = np.isfinite(mean_scores)
    if not np.any(finite_mask):
        raise ValueError(
            "All CV scores are NaN — every fold failed for every alpha. "
            "Check for data issues or increase max_iter."
        )

    if minimize:
        best_idx = int(np.nanargmin(mean_scores))
    else:
        best_idx = int(np.nanargmax(mean_scores))

    best_alpha = float(alpha_grid[best_idx])

    # 5. Cache results (copy arrays to prevent mutation corruption)
    if cache is not None and cache_key_fn is not None:
        cache.put(cache_key, (best_alpha, mean_scores.copy(), all_scores.copy()))

    return best_alpha, mean_scores, all_scores
|