statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,410 @@
1
+ """
2
+ Shared base class and utilities for cross-validated estimators.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = ["CVEstimatorBase", "folds_are_complete", "INTERCEPT_CLIP_BOUND"]
8
+
9
+ import hashlib
10
+ from collections import OrderedDict
11
+ from typing import Any, Dict, List, Optional, Tuple, Union
12
+
13
+ import numpy as np
14
+
15
+ from statgpu._base import BaseEstimator
16
+
17
# Bound used to clip the intercept in the cross-validation proximal operators
# (shared across CV estimators).
INTERCEPT_CLIP_BOUND: float = 15.0
19
+ from statgpu._config import Device
20
+ from statgpu.backends import _to_numpy
21
+
22
+
23
+ def _torch_cuda_available():
24
+ """Check if torch CUDA is available (shared utility)."""
25
+ try:
26
+ import torch
27
+ return torch.cuda.is_available()
28
+ except Exception:
29
+ return False
30
+
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # K-fold splitting
34
+ # ---------------------------------------------------------------------------
35
+
36
def kfold_indices(
    n_samples: int,
    n_splits: int = 5,
    random_state: Optional[int] = None,
    shuffle: bool = True,
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Build K-fold (train_idx, val_idx) pairs over ``range(n_samples)``.

    The (optionally shuffled) index vector is cut into ``n_splits``
    contiguous validation chunks. The first ``n_samples % n_splits`` chunks
    hold one extra element. Each training set is the concatenation, in
    order, of all other chunks.

    Parameters
    ----------
    n_samples : int
        Total number of samples.
    n_splits : int
        Number of folds (must satisfy ``2 <= n_splits <= n_samples``).
    random_state : int or None
        Seed for ``np.random.default_rng``. Only used when ``shuffle`` is True.
    shuffle : bool
        Whether to shuffle indices before splitting.

    Returns
    -------
    folds : list of (train_idx, val_idx) tuples

    Raises
    ------
    ValueError
        If ``n_splits < 2`` or ``n_splits > n_samples``.
    """
    if n_splits < 2:
        raise ValueError(f"n_splits={n_splits} must be at least 2")
    if n_splits > n_samples:
        raise ValueError(
            f"n_splits={n_splits} cannot be greater than n_samples={n_samples}"
        )

    order = np.arange(n_samples)
    if shuffle:
        np.random.default_rng(random_state).shuffle(order)

    # np.array_split gives the leading (n_samples % n_splits) chunks one
    # extra element, i.e. the classic K-fold size layout.
    chunks = np.array_split(order, n_splits)
    return [
        (np.concatenate(chunks[:k] + chunks[k + 1:]), chunk)
        for k, chunk in enumerate(chunks)
    ]
83
+
84
+
85
def folds_are_complete(folds, n_samples: int) -> bool:
    """Return True if the validation parts of *folds* partition ``range(n_samples)``.

    Each fold is a ``(train_idx, val_idx)`` pair. Only ``val_idx`` is
    inspected. Every sample index must appear in exactly one validation set.
    An empty collection of folds covers only the empty dataset
    (``n_samples == 0``).

    Parameters
    ----------
    folds : iterable of (train_idx, val_idx)
    n_samples : int
        Number of samples the folds are expected to cover.

    Returns
    -------
    bool
    """
    val_parts = [fold[1] for fold in folds]
    if not val_parts:
        # np.concatenate([]) raises ValueError; treat "no folds" explicitly.
        return n_samples == 0
    val_indices = np.concatenate(val_parts)
    if len(val_indices) != n_samples:
        return False
    return bool(np.array_equal(np.sort(val_indices), np.arange(n_samples)))
91
+
92
+
93
def hash_cv_data(X, y, sample_weight=None) -> bytes:
    """Compute a compact 16-byte blake2b digest of X, y and optional sample_weight.

    All inputs are first converted to float64 NumPy arrays, so the digest
    does not depend on the original dtype or backend.

    For datasets with ``n * p <= 10_000_000`` the full contents are hashed.
    For larger datasets, the digest covers:

    * up to 100 evenly spaced rows (always including the first and last row);
    * the sampled row indices;
    * the mean/std of each array.

    This keeps hashing fast at the price of a small collision risk.

    The header records ``n``, ``p``, the length of ``y`` and the length of
    ``sample_weight`` (``-1`` when absent). Different (y, sample_weight)
    combinations therefore cannot serialize to the same byte stream.

    Returns
    -------
    bytes
        The 16-byte digest.
    """
    h = hashlib.blake2b(digest_size=16)
    X_np = np.asarray(_to_numpy(X), dtype=np.float64)
    y_np = np.asarray(_to_numpy(y), dtype=np.float64).ravel()
    sw_np = (
        None
        if sample_weight is None
        else np.asarray(_to_numpy(sample_weight), dtype=np.float64).ravel()
    )
    n, p = X_np.shape
    sw_len = -1 if sw_np is None else sw_np.shape[0]
    # Frame the stream with every length. Without this, a y of length 2n
    # with no weights would hash like a y of length n followed by n weights.
    h.update(np.asarray([n, p, y_np.shape[0], sw_len], dtype=np.int64).tobytes())

    _FULL_HASH_THRESHOLD = 10_000_000  # n * p threshold for full hashing
    if n * p <= _FULL_HASH_THRESHOLD:
        h.update(X_np.tobytes())
        h.update(y_np.tobytes())
        if sw_np is not None:
            h.update(sw_np.tobytes())
        return h.digest()

    # Large dataset: sampled rows + their indices + aggregate statistics.
    step = max(1, n // 100)
    idx = np.arange(0, n, step)[:100]  # starts at 0, so row 0 is included
    if idx[-1] != n - 1:
        idx = np.append(idx, n - 1)  # always include the last row
    # Hash the row indices so that reordered data does not collide.
    h.update(idx.astype(np.int64).tobytes())
    h.update(X_np[idx].tobytes())
    h.update(y_np[idx].tobytes())
    h.update(np.asarray([X_np.mean(), X_np.std()], dtype=np.float64).tobytes())
    h.update(np.asarray([y_np.mean(), y_np.std()], dtype=np.float64).tobytes())
    if sw_np is not None:
        h.update(sw_np[idx].tobytes())
        h.update(np.asarray([sw_np.mean(), sw_np.std()], dtype=np.float64).tobytes())
    return h.digest()
136
+
137
+
138
def validate_cv_sample_weight(sample_weight, n_samples: int):
    """Check that CV sample weights have the right length and are finite and >= 0.

    Validation runs on a flattened float64 NumPy copy (a single
    device-to-host transfer for GPU inputs). The object returned is the
    caller's original ``sample_weight``, so CuPy/Torch inputs keep their
    backend.

    Returns
    -------
    None if ``sample_weight`` is None, otherwise ``sample_weight`` unchanged.

    Raises
    ------
    ValueError
        If the length differs from ``n_samples``, any weight is negative,
        or any weight is not finite.
    """
    if sample_weight is None:
        return None

    host_weights = _to_numpy(sample_weight).ravel().astype(np.float64)
    length = host_weights.shape[0]
    if length != n_samples:
        raise ValueError(
            f"sample_weight length {length} != n_samples {n_samples}"
        )
    # The negativity check runs first, so -inf is reported as negative while
    # NaN and +inf are reported as non-finite.
    if (host_weights < 0).any():
        raise ValueError("sample_weight must be non-negative")
    if not np.isfinite(host_weights).all():
        raise ValueError("sample_weight must be finite")
    return sample_weight
159
+
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # LRU cache for CV results
163
+ # ---------------------------------------------------------------------------
164
+
165
class CVCache:
    """Small thread-safe LRU cache for cross-validation results.

    Every read and write holds an internal lock. Once more than ``maxsize``
    entries are stored, the least recently used ones are evicted.

    Parameters
    ----------
    maxsize : int
        Maximum number of cached entries. Must be >= 0; 0 stores nothing.

    Notes
    -----
    ``get`` returns None for a missing key, so a stored value of None is
    indistinguishable from a miss.
    """

    def __init__(self, maxsize: int = 64):
        import threading

        if maxsize < 0:
            # A negative bound would make put() pop from an empty dict and
            # raise KeyError.
            raise ValueError(f"maxsize must be >= 0, got {maxsize}")
        self._cache: OrderedDict = OrderedDict()
        self._maxsize = maxsize
        self._lock = threading.Lock()

    def get(self, key: str):
        """Return the value cached under *key* (marking it most recent), or None."""
        with self._lock:
            try:
                self._cache.move_to_end(key)
            except KeyError:
                return None
            return self._cache[key]

    def put(self, key: str, value) -> None:
        """Store *value* under *key*, then evict least recently used overflow."""
        with self._lock:
            self._cache[key] = value
            self._cache.move_to_end(key)
            while len(self._cache) > self._maxsize:
                self._cache.popitem(last=False)

    @staticmethod
    def make_key(*args) -> str:
        """Generate a blake2b hex key from arbitrary arguments.

        Array-likes (objects with ``shape`` plus ``tobytes`` or ``detach``,
        e.g. NumPy/CuPy arrays and Torch tensors) are hashed by shape, dtype
        and raw content via ``_to_numpy``. This avoids the collisions caused
        by ``str()`` truncating large arrays. Everything else is hashed via
        ``str()``.

        Each hashed piece is length-prefixed, so argument boundaries are
        unambiguous: ``make_key("ab", "c") != make_key("a", "bc")``.
        """
        h = hashlib.blake2b(digest_size=32)
        for arg in args:
            # Torch tensors have no ``tobytes``; detect them via ``detach``
            # so they are hashed by content, not by their truncated repr.
            is_array = hasattr(arg, "shape") and (
                hasattr(arg, "tobytes") or hasattr(arg, "detach")
            )
            if is_array:
                arr = np.ascontiguousarray(_to_numpy(arg))
                parts = (
                    b"array",
                    str(arr.shape).encode(),
                    str(arr.dtype).encode(),
                    arr.tobytes(),
                )
            else:
                parts = (b"obj", str(arg).encode())
            for part in parts:
                h.update(len(part).to_bytes(8, "little"))
                h.update(part)
        return h.hexdigest()
213
+
214
+
215
+ # ---------------------------------------------------------------------------
216
+ # GPU input detection
217
+ # ---------------------------------------------------------------------------
218
+
219
def detect_gpu_input(X, y) -> Tuple[str, Any, Any]:
    """Detect whether inputs are CuPy or Torch arrays.

    Returns
    -------
    backend : str
        One of 'numpy', 'cupy', 'torch'.
    X, y : arrays
        The original arrays when both share one GPU backend, or when neither
        is a CuPy/Torch array. When the backends differ (including a GPU
        array paired with a non-GPU input), a RuntimeWarning is emitted and
        both inputs are converted with ``_to_numpy``. This keeps the returned
        arrays consistent with the 'numpy' label.
    """
    import warnings as _warnings

    x_type = None
    y_type = None

    try:
        import cupy as cp
    except ImportError:
        pass
    else:
        if isinstance(X, cp.ndarray):
            x_type = 'cupy'
        if isinstance(y, cp.ndarray):
            y_type = 'cupy'

    try:
        import torch
    except ImportError:
        pass
    else:
        if isinstance(X, torch.Tensor):
            x_type = 'torch'
        if isinstance(y, torch.Tensor):
            y_type = 'torch'

    if x_type == y_type:
        # Same GPU backend for both, or neither is a GPU array.
        return (x_type or 'numpy'), X, y

    # Previously a lone GPU array (e.g. CuPy X with NumPy y) was returned
    # unconverted under the 'numpy' label; convert in every mixed case.
    _warnings.warn(
        f"Mixed backend detected: X is {x_type or 'numpy'} but y is "
        f"{y_type or 'numpy'}. "
        f"Both arrays should use the same backend. Falling back to numpy.",
        RuntimeWarning,
        stacklevel=2,
    )
    return 'numpy', _to_numpy(X), _to_numpy(y)
270
+
271
+
272
+ # ---------------------------------------------------------------------------
273
+ # Batch MSE computation
274
+ # ---------------------------------------------------------------------------
275
+
276
def batch_mse(
    X_val,
    y_val,
    coefs: np.ndarray,
    intercepts: Optional[np.ndarray] = None,
    sample_weight=None,
    chunk_size: int = 256,
) -> np.ndarray:
    """Compute the (optionally weighted) MSE of many linear models on one validation set.

    Models are evaluated ``chunk_size`` at a time, so the prediction buffer
    is O(chunk_size * n_val) instead of O(n_models * n_val).

    Parameters
    ----------
    X_val : array, shape (n_val, n_features)
    y_val : array, shape (n_val,)
    coefs : array, shape (n_models, n_features)
    intercepts : array, shape (n_models,) or None
        A single value (size-1 array or scalar) is broadcast to every model.
    sample_weight : array, shape (n_val,) or None
        The weighted MSE is ``sum(w * r**2) / sum(w)``. If the weights do not
        sum to a positive number, every model's MSE is NaN.
    chunk_size : int
        Number of models to process at once (default 256, must be >= 1).

    Returns
    -------
    mse : ndarray of float64, shape (n_models,)

    Raises
    ------
    ValueError
        If ``chunk_size < 1`` or any input shape is inconsistent.
    """
    if chunk_size < 1:
        # range(..., 0) raises an opaque error. A negative step silently
        # skips the loop and would return the uninitialized np.empty buffer.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    X_val = _to_numpy(X_val)
    y_val = _to_numpy(y_val).ravel()
    coefs = _to_numpy(coefs)

    # Validate dimensions
    if coefs.ndim != 2:
        raise ValueError(f"coefs must be 2D (n_models, n_features), got shape {coefs.shape}")
    if X_val.ndim != 2:
        raise ValueError(f"X_val must be 2D (n_samples, n_features), got shape {X_val.shape}")
    if coefs.shape[1] != X_val.shape[1]:
        raise ValueError(
            f"Feature dimension mismatch: coefs has {coefs.shape[1]} features, "
            f"X_val has {X_val.shape[1]} features"
        )
    if y_val.shape[0] != X_val.shape[0]:
        raise ValueError(
            f"Sample count mismatch: y has {y_val.shape[0]} samples, "
            f"X_val has {X_val.shape[0]} samples"
        )
    n_models = coefs.shape[0]
    n_val = X_val.shape[0]

    if intercepts is not None:
        # Flatten so (n_models, 1) inputs do not broadcast into a 3-D buffer.
        intercepts = np.asarray(_to_numpy(intercepts)).ravel()
        if intercepts.shape[0] not in (1, n_models):
            raise ValueError(
                f"intercepts has {intercepts.shape[0]} entries, expected "
                f"{n_models} (one per model) or 1"
            )
        intercepts = np.broadcast_to(intercepts, (n_models,))

    sw = None
    sw_sum = 0.0
    if sample_weight is not None:
        sw = np.asarray(_to_numpy(sample_weight)).ravel()
        if sw.shape[0] != n_val:
            raise ValueError(
                f"Sample count mismatch: sample_weight has {sw.shape[0]} samples, "
                f"X_val has {n_val} samples"
            )
        sw_sum = float(np.sum(sw))
        if not sw_sum > 0:
            # Zero (or NaN) total weight: the weighted mean is undefined.
            return np.full(n_models, np.nan, dtype=np.float64)

    mse = np.empty(n_models, dtype=np.float64)

    # Process in chunks to limit peak memory.
    for start in range(0, n_models, chunk_size):
        stop = min(start + chunk_size, n_models)
        y_pred = coefs[start:stop] @ X_val.T  # (chunk, n_val)
        if intercepts is not None:
            y_pred = y_pred + intercepts[start:stop, None]
        sq_resid = (y_val[None, :] - y_pred) ** 2
        if sw is None:
            mse[start:stop] = np.mean(sq_resid, axis=1)
        else:
            mse[start:stop] = np.sum(sq_resid * sw[None, :], axis=1) / sw_sum

    return mse
357
+
358
+
359
+ # ---------------------------------------------------------------------------
360
+ # Base class
361
+ # ---------------------------------------------------------------------------
362
+
363
class CVEstimatorBase(BaseEstimator):
    """
    Shared scaffolding for model-specific cross-validated estimators.

    Deliberately thin: every concrete CV estimator keeps its own search
    routine and fitted attributes. This base only stores the common CV
    settings and forwards ``predict``, ``score`` and ``summary`` to
    ``estimator_``.
    """

    def __init__(
        self,
        *,
        cv: int = 5,
        random_state: Optional[int] = None,
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
    ):
        super().__init__(device=device, n_jobs=n_jobs)
        n_folds = int(cv)
        self.cv = n_folds
        if n_folds < 2:
            raise ValueError(f"cv must be >= 2, got {n_folds}")
        self.random_state = random_state

        # Fitted attributes common to CV estimators, initialised to None here.
        self.best_score_ = None
        self.cv_results_ = None
        self.estimator_ = None

    def _fitted_estimator(self):
        # Shared guard for the delegating methods: the fitted check comes
        # first, then the presence of ``estimator_`` is checked.
        self._check_is_fitted()
        estimator = self.estimator_
        if estimator is None:
            raise RuntimeError("No fitted base estimator is available.")
        return estimator

    def predict(self, X):
        return self._fitted_estimator().predict(X)

    def score(self, X, y):
        return self._fitted_estimator().score(X, y)

    def summary(self):
        estimator = self._fitted_estimator()
        if not hasattr(estimator, "summary"):
            raise RuntimeError(
                f"{estimator.__class__.__name__} does not implement summary()."
            )
        return estimator.summary()
@@ -0,0 +1,167 @@
1
+ """
2
+ Generic cross-validation engine for penalized GLM models.
3
+
4
+ Provides a reusable CV loop that can be parameterized by:
5
+ - Any loss function (squared_error, logistic, poisson, etc.)
6
+ - Any penalty type (l1, l2, elasticnet, scad, mcp, etc.)
7
+ - Any backend (numpy, cupy, torch)
8
+
9
+ .. note::
10
+
11
+ **Reference Implementation**: ``run_cv`` is a simple, readable reference
12
+ implementation intended for:
13
+ - Custom estimators that need a basic CV loop
14
+ - Testing and prototyping new CV strategies
15
+ - Documentation of the CV algorithm
16
+
17
+ The production CV paths (PenalizedGLM_CV, LassoCV, RidgeCV, etc.) use
18
+ their own optimized loops with warm-starting, fold batching, and
19
+ backend-specific optimizations. For production use, prefer those
20
+ estimators directly.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ __all__ = ["run_cv"]
26
+
27
+ import logging
28
+ from typing import Any, Callable, List, Optional, Tuple
29
+
30
+ import numpy as np
31
+
32
+ from statgpu.cross_validation._base import (
33
+ CVCache,
34
+ kfold_indices,
35
+ )
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
def run_cv(
    X,
    y,
    alpha_grid: np.ndarray,
    evaluate_fold_fn: Callable,
    n_folds: int = 5,
    random_state: Optional[int] = None,
    minimize: bool = True,
    cache: Optional[CVCache] = None,
    cache_key_fn: Optional[Callable] = None,
    sample_weight=None,
    raise_on_error: bool = False,
) -> Tuple[float, np.ndarray, np.ndarray]:
    """Execute K-fold cross-validation over a grid of regularization values.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Feature matrix.
    y : array, shape (n_samples,)
        Target vector.
    alpha_grid : array, shape (n_alphas,)
        Regularization parameter grid.
    evaluate_fold_fn : callable
        Function ``(X_train, y_train, X_val, y_val, alpha,
        sample_weight_train=None, sample_weight_val=None) -> score``
        that trains on the training fold and returns a scalar score on
        the validation fold.
    n_folds : int
        Number of CV folds.
    random_state : int or None
        Random seed for fold generation.
    minimize : bool
        If True, lower score is better. If False, higher score is better.
    cache : CVCache or None
        Optional LRU cache for CV results.
    cache_key_fn : callable or None
        Function ``(X, y, alpha_grid, folds) -> str`` for cache key.
    sample_weight : array or None
        Optional sample weights (passed through to evaluate_fold_fn).
    raise_on_error : bool, default False
        Controls exceptions of type ValueError, FloatingPointError,
        LinAlgError or RuntimeError raised by evaluate_fold_fn.
        If True, they are re-raised. Otherwise they are logged as a warning
        and the score is recorded as NaN. Other exception types always
        propagate.

    Returns
    -------
    best_alpha : float
        Alpha value that optimizes the mean CV score.
    mean_scores : ndarray, shape (n_alphas,)
        Mean CV score for each alpha, ignoring failed (NaN) folds. It is NaN
        for an alpha whose folds all failed.
    all_scores : ndarray, shape (n_folds, n_alphas)
        Per-fold CV scores.

    Raises
    ------
    ValueError
        On mismatched input lengths, an empty ``alpha_grid``, or when no
        alpha has a finite mean score.

    Notes
    -----
    The cache stores copies of the result arrays, and a cache hit returns
    fresh copies. Callers may therefore mutate the returned arrays without
    corrupting the cache entry.
    """
    # 0. Validate inputs
    n_samples = X.shape[0]
    if y.shape[0] != n_samples:
        raise ValueError(f"X and y have different number of samples: {n_samples} vs {y.shape[0]}")
    if len(alpha_grid) == 0:
        raise ValueError("alpha_grid must not be empty")
    if sample_weight is not None and len(sample_weight) != n_samples:
        raise ValueError(
            f"sample_weight length {len(sample_weight)} != n_samples {n_samples}"
        )

    # 1. Generate folds
    folds = kfold_indices(n_samples, n_folds, random_state)

    # 2. Check cache
    use_cache = cache is not None and cache_key_fn is not None
    cache_key = None
    if use_cache:
        cache_key = cache_key_fn(X, y, alpha_grid, folds)
        cached = cache.get(cache_key)
        if cached is not None:
            cached_alpha, cached_mean, cached_all = cached
            # Hand out copies: returning the stored arrays would let callers
            # mutate the cache entry in place.
            return cached_alpha, cached_mean.copy(), cached_all.copy()

    # 3. Evaluate each (fold, alpha) pair
    n_alphas = len(alpha_grid)
    all_scores = np.full((n_folds, n_alphas), np.nan)

    for fold_idx, (train_idx, val_idx) in enumerate(folds):
        X_train = X[train_idx]
        y_train = y[train_idx]
        X_val = X[val_idx]
        y_val = y[val_idx]

        sw_train = sample_weight[train_idx] if sample_weight is not None else None
        sw_val = sample_weight[val_idx] if sample_weight is not None else None

        for alpha_idx, alpha in enumerate(alpha_grid):
            try:
                score = evaluate_fold_fn(
                    X_train, y_train, X_val, y_val, alpha,
                    sample_weight_train=sw_train,
                    sample_weight_val=sw_val,
                )
                all_scores[fold_idx, alpha_idx] = score
            except (ValueError, FloatingPointError, np.linalg.LinAlgError, RuntimeError) as exc:
                if raise_on_error:
                    raise
                all_scores[fold_idx, alpha_idx] = np.nan
                logger.warning(
                    "CV fold %d, alpha_idx %d failed: %s",
                    fold_idx, alpha_idx, exc,
                )

    # 4. Aggregate across folds. This matches np.nanmean(all_scores, axis=0)
    # but avoids its "Mean of empty slice" RuntimeWarning for alphas whose
    # folds all failed; those entries become NaN (0 / 0).
    valid = ~np.isnan(all_scores)
    n_valid = valid.sum(axis=0)
    score_sums = np.where(valid, all_scores, 0.0).sum(axis=0)
    with np.errstate(invalid="ignore", divide="ignore"):
        mean_scores = score_sums / n_valid

    # Guard against all-NaN slices (all folds failed for every alpha)
    finite_mask = np.isfinite(mean_scores)
    if not np.any(finite_mask):
        raise ValueError(
            "All CV scores are NaN — every fold failed for every alpha. "
            "Check for data issues or increase max_iter."
        )

    if minimize:
        best_idx = int(np.nanargmin(mean_scores))
    else:
        best_idx = int(np.nanargmax(mean_scores))

    best_alpha = float(alpha_grid[best_idx])

    # 5. Cache results (copy arrays to prevent mutation corruption)
    if use_cache:
        cache.put(cache_key, (best_alpha, mean_scores.copy(), all_scores.copy()))

    return best_alpha, mean_scores, all_scores
@@ -0,0 +1,7 @@
1
+ """
2
+ Diagnostics for regression models.
3
+ """
4
+
5
+ from ._regression_diagnostics import RegressionDiagnostics, diagnose_model
6
+
7
+ __all__ = ['RegressionDiagnostics', 'diagnose_model']