statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,310 @@
1
+ """Empirical covariance estimation with GPU support."""
2
+
3
+ from __future__ import annotations
4
+
5
+ __all__ = ["EmpiricalCovariance"]
6
+
7
+ from typing import Optional, Union
8
+
9
+ import numpy as np
10
+
11
+ from statgpu._base import BaseEstimator
12
+ from statgpu._config import Device
13
+ from statgpu.backends import (
14
+ _LINALG_ERRORS,
15
+ _get_xp,
16
+ _is_cupy_array,
17
+ _is_torch_array,
18
+ _resolve_backend,
19
+ _to_float_scalar,
20
+ _to_numpy,
21
+ _torch_dev,
22
+ xp_zeros,
23
+ xp_asarray,
24
+ )
25
+
26
+
27
def _detect_backend(X, device: Device) -> str:
    """Pick the array backend name for *X*.

    The type of *X* takes precedence: a torch tensor selects ``"torch"`` and
    a CuPy array selects ``"cupy"``. For any other input the choice is driven
    by *device*.

    Parameters
    ----------
    X : array-like
        Input data whose container type is inspected.
    device : Device
        Resolved compute device, consulted only when *X* is neither a torch
        tensor nor a CuPy array.

    Returns
    -------
    str
        One of ``"torch"``, ``"cupy"`` or ``"numpy"``.

    Raises
    ------
    RuntimeError
        If *device* is ``Device.CUDA`` but CuPy cannot be imported. The
        original ``ImportError`` is attached as ``__cause__``.
    """
    if _is_torch_array(X):
        return "torch"
    if _is_cupy_array(X):
        return "cupy"

    # Plain (e.g. numpy) input: fall back to the device setting.
    if device == Device.TORCH:
        return "torch"
    if device == Device.CUDA:
        try:
            import cupy as cp  # noqa: F401
        except ImportError as exc:
            # Chain explicitly so the traceback shows the import failure as
            # the cause rather than as an incidental "during handling" error.
            raise RuntimeError(
                "CuPy is required for device='cuda' but is not installed. "
                "Use device='auto' to fall back to CPU automatically."
            ) from exc
        return "cupy"
    return "numpy"
46
+
47
+
48
+ def _torch_device_from_data(X) -> Optional[str]:
49
+ """Extract torch device from tensor, or None for non-torch inputs."""
50
+ try:
51
+ import torch
52
+ if isinstance(X, torch.Tensor):
53
+ return str(X.device)
54
+ except (ImportError, AttributeError):
55
+ pass
56
+ return None
57
+
58
+
59
class EmpiricalCovariance(BaseEstimator):
    """
    Maximum likelihood covariance estimator with GPU acceleration.

    Computes the sample covariance matrix, its inverse (precision), and
    provides log-likelihood scoring and Mahalanobis distance computation.

    Parameters
    ----------
    assume_centered : bool, default=False
        If True, data is assumed to be already centered. If False, the
        mean is estimated and subtracted before computing the covariance.
    device : str or Device, default='auto'
        Computation device: ``'cpu'``, ``'cuda'``, ``'torch'``, or ``'auto'``.
    n_jobs : int or None, default=None
        Number of parallel jobs (reserved for future use).

    Attributes
    ----------
    covariance_ : array, shape (n_features, n_features)
        Estimated covariance matrix (normalised by ``n_samples``).
    location_ : array, shape (n_features,)
        Estimated location (mean) vector; zeros if *assume_centered*.
    precision_ : array, shape (n_features, n_features)
        Inverse of the covariance as computed by ``_stable_inv``.
    n_samples_ : int
        Number of samples seen during fit.
    n_features_ : int
        Number of features seen during fit.
    """

    def __init__(
        self,
        assume_centered: bool = False,
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
    ):
        super().__init__(device=device, n_jobs=n_jobs)
        self.assume_centered = assume_centered

    def fit(self, X, y=None):
        """Fit the covariance model to *X*.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data. A 1-D input is treated as ``n_samples`` values of
            a single feature.
        y : ignored
            Not used, present for API compatibility.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If fewer than two samples are supplied, or if the covariance
            cannot be inverted by ``_stable_inv``.
        """
        backend_name = _detect_backend(X, self._get_compute_device())
        xp = _get_xp(backend_name)

        # For the torch backend, choose the target device from the resolved
        # compute device: CUDA for 'torch'/'cuda', CPU otherwise.
        _ref = None
        if backend_name == "torch":
            import torch

            _dev = self._get_compute_device()
            _cuda_dev = "cuda" if _dev.value in ("torch", "cuda") else "cpu"
            _ref = torch.empty(0, dtype=torch.float64, device=_cuda_dev)

        X_arr = xp_asarray(X, dtype=xp.float64, xp=xp, ref_arr=_ref)
        if X_arr.ndim == 1:
            X_arr = X_arr.reshape(-1, 1)

        n_samples = int(X_arr.shape[0])
        n_features = int(X_arr.shape[1])

        if n_samples < 2:
            raise ValueError(
                f"Need at least 2 samples to estimate covariance, got {n_samples}"
            )

        # Center if needed
        if self.assume_centered:
            location = xp_zeros(n_features, xp.float64, xp, X_arr)
        else:
            location = xp.mean(X_arr, axis=0)
            X_arr = X_arr - location

        # Sample covariance (maximum likelihood): S = X^T X / n
        S = (X_arr.T @ X_arr) / float(n_samples)

        # Precision (inverse) with jitter stabilization
        precision = _stable_inv(S, xp, backend_name)

        self.covariance_ = S
        self.location_ = location
        self.precision_ = precision
        self.n_samples_ = n_samples
        self.n_features_ = n_features
        self._backend_name = backend_name
        self._fitted = True
        return self

    def _coerce_to_fitted_shape(self, X):
        """Convert *X* to a float64 2-D array compatible with the fitted model.

        A 1-D input is interpreted as many samples of one feature when the
        model was fitted on a single feature, and as one sample otherwise.
        This keeps ``score`` and ``mahalanobis`` consistent with each other.

        Returns
        -------
        xp : module
            Array namespace selected for *X*.
        X_arr : array of shape (n_samples, n_features_)

        Raises
        ------
        ValueError
            If *X* is not 1-D/2-D or its feature count differs from
            ``n_features_``. Without this check a ``(n, 1)`` input would be
            silently broadcast against a ``(p,)`` location vector.
        """
        backend_name = _detect_backend(X, self._get_compute_device())
        xp = _get_xp(backend_name)
        X_arr = xp_asarray(X, dtype=xp.float64, xp=xp)

        n_features = int(self.n_features_)
        if X_arr.ndim == 1:
            if n_features == 1:
                X_arr = X_arr.reshape(-1, 1)
            else:
                X_arr = X_arr.reshape(1, -1)

        if X_arr.ndim != 2 or int(X_arr.shape[1]) != n_features:
            raise ValueError(
                f"X has shape {tuple(X_arr.shape)}, but "
                f"{type(self).__name__} was fitted with {n_features} features"
            )
        return xp, X_arr

    def predict(self, X):
        """Return squared Mahalanobis distances for *X* under the fitted model.

        This is an alias for :meth:`mahalanobis`.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        distances : ndarray of shape (n_samples,)
            Squared Mahalanobis distances.
        """
        return self.mahalanobis(X)

    def score(self, X, y=None):
        """Compute the average log-likelihood of *X* under the fitted Gaussian.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test data.
        y : ignored

        Returns
        -------
        ll : float
            Average log-likelihood per observation.

        Raises
        ------
        ValueError
            If the feature count of *X* does not match the fitted model.
        """
        self._check_is_fitted()
        xp, X_arr = self._coerce_to_fitted_shape(X)

        n_samples = int(X_arr.shape[0])
        p = int(X_arr.shape[1])

        # Bring the fitted quantities next to X (same backend/device).
        loc = xp_asarray(self.location_, dtype=xp.float64, xp=xp, ref_arr=X_arr)
        prec = xp_asarray(self.precision_, dtype=xp.float64, xp=xp, ref_arr=X_arr)
        cov = xp_asarray(self.covariance_, dtype=xp.float64, xp=xp, ref_arr=X_arr)

        X_centered = X_arr - loc

        # Mahalanobis term: sum over samples of (x-mu)^T S^{-1} (x-mu)
        M = X_centered @ prec
        mahal_sum = _to_float_scalar(xp.sum(M * X_centered))

        # log(det(S)) via slogdet for numerical stability; the sign is not used.
        _sign, logdet = xp.linalg.slogdet(cov)
        logdet_val = _to_float_scalar(logdet)

        # Average log-likelihood:
        #   LL = -(1/2) * (p * log(2*pi) + log(det(S)) + (1/n) * sum(mahal))
        ll = -0.5 * (p * np.log(2.0 * np.pi) + logdet_val + mahal_sum / n_samples)
        return float(ll)

    def mahalanobis(self, X):
        """Compute squared Mahalanobis distances of observations in *X*.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        distances : ndarray of shape (n_samples,)
            Squared Mahalanobis distances.

        Raises
        ------
        ValueError
            If the feature count of *X* does not match the fitted model.
        """
        self._check_is_fitted()
        xp, X_arr = self._coerce_to_fitted_shape(X)

        loc = xp_asarray(self.location_, dtype=xp.float64, xp=xp, ref_arr=X_arr)
        prec = xp_asarray(self.precision_, dtype=xp.float64, xp=xp, ref_arr=X_arr)

        X_centered = X_arr - loc

        # Row-wise quadratic form (x-mu)^T prec (x-mu) without an n x n product.
        M = X_centered @ prec
        mahal = xp.sum(M * X_centered, axis=1)

        return _to_numpy(mahal)

    def get_params(self, deep=True):
        """Return the base estimator parameters plus ``assume_centered``."""
        params = super().get_params(deep=deep)
        params["assume_centered"] = self.assume_centered
        return params

    def set_params(self, **params):
        """Set ``assume_centered`` locally and forward the rest to the base class.

        Returns
        -------
        self
        """
        if "assume_centered" in params:
            self.assume_centered = params.pop("assume_centered")
        if params:
            super().set_params(**params)
        return self
259
+
260
+
261
+ # ---------------------------------------------------------------------------
262
+ # Internal helpers
263
+ # ---------------------------------------------------------------------------
264
+
265
def _stable_inv(S, xp, backend_name: str):
    """Invert *S* with a small diagonal jitter for numerical stability.

    Each attempt inverts ``S + jitter * I``. The initial jitter is
    ``1e-10 * max(|trace(S)| / p, 1)``, so even the first attempt is
    regularised. Whenever the inversion raises or yields non-finite values
    the jitter is multiplied by 10 and the inversion is retried, for at
    most 12 attempts.

    Parameters
    ----------
    S : array of shape (p, p)
        Matrix to invert.
    xp : module
        Array namespace matching *S*.
    backend_name : str
        Backend name; ``"torch"`` makes the identity matrix get created on
        the same device as *S*.

    Returns
    -------
    array of shape (p, p)
        The first inverse whose largest absolute entry is finite.

    Raises
    ------
    ValueError
        If no attempt produces a finite inverse. The last backend error, if
        any, is attached as ``__cause__``.
    """
    p = int(S.shape[0])

    trace_S = _to_float_scalar(xp.trace(S))
    jitter = max(abs(trace_S) / max(p, 1), 1.0) * 1e-10

    torch_dev = None
    if backend_name == "torch":
        try:
            import torch

            if isinstance(S, torch.Tensor):
                torch_dev = S.device
        except (ImportError, AttributeError):
            pass

    # Build the identity once; torch needs it on the same device as S.
    if torch_dev is not None:
        eye = xp.eye(p, dtype=xp.float64, device=torch_dev)
    else:
        eye = xp.eye(p, dtype=xp.float64)

    last_error = None
    for _ in range(12):
        try:
            # `jitter > 0` is false only when the trace is NaN; S is then
            # used as-is and the finiteness check below rejects the result.
            S_work = S + jitter * eye if jitter > 0 else S

            inv_S = xp.linalg.inv(S_work)
            test_val = _to_float_scalar(xp.max(xp.abs(inv_S)))
            if np.isfinite(test_val):
                return inv_S
        except _LINALG_ERRORS + (ValueError,) as exc:
            # Keep the failure so the final error can report its cause.
            last_error = exc
        jitter *= 10.0

    raise ValueError(
        "Covariance matrix is singular and cannot be inverted even with "
        "diagonal jitter. Consider using LedoitWolf or OAS shrinkage."
    ) from last_error
@@ -0,0 +1,248 @@
1
+ """Ledoit-Wolf and Oracle Approximating Shrinkage (OAS) covariance estimators."""
2
+
3
+ from __future__ import annotations
4
+
5
+ __all__ = ["LedoitWolf", "OAS"]
6
+
7
+ from typing import Optional, Union
8
+
9
+ import numpy as np
10
+
11
+ from statgpu._config import Device
12
+ from statgpu.backends import _get_xp, _to_float_scalar, _torch_dev, xp_zeros, xp_eye
13
+
14
+ from statgpu.covariance._empirical import EmpiricalCovariance, _detect_backend, _stable_inv
15
+
16
+
17
class LedoitWolf(EmpiricalCovariance):
    """
    Ledoit-Wolf shrinkage covariance estimator with GPU support.

    Computes a shrinkage estimator that is a convex combination of the
    sample covariance and a structured target (scaled identity). The
    optimal shrinkage intensity is derived from the Ledoit & Wolf (2004)
    analytical formula.

    Parameters
    ----------
    assume_centered : bool, default=False
        If True, data is assumed to be already centered.
    device : str or Device, default='auto'
        Computation device: ``'cpu'``, ``'cuda'``, ``'torch'``, or ``'auto'``.
    n_jobs : int or None, default=None
        Number of parallel jobs (reserved for future use).

    Attributes
    ----------
    covariance_ : ndarray of shape (n_features, n_features)
        Estimated shrunk covariance matrix.
    location_ : ndarray of shape (n_features,)
        Estimated mean (zero if *assume_centered* is True).
    precision_ : ndarray of shape (n_features, n_features)
        Inverse of the shrunk covariance as computed by ``_stable_inv``.
    shrinkage_ : float
        Shrinkage intensity in [0, 1].
    n_samples_ : int
        Number of training samples.
    n_features_ : int
        Number of features.
    """

    def fit(self, X, y=None):
        """Fit the Ledoit-Wolf covariance model to *X*.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data. A 1-D input is treated as ``n_samples`` values of
            a single feature.
        y : ignored

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If fewer than two samples are supplied, or if the shrunk
            covariance cannot be inverted by ``_stable_inv``.
        """
        compute_device = self._get_compute_device()
        backend_name = _detect_backend(X, compute_device)
        xp = _get_xp(backend_name)

        if backend_name == "torch":
            import torch

            # Same rule as EmpiricalCovariance.fit: use CUDA only when the
            # resolved compute device asks for it, so CPU tensors keep
            # working on machines without a GPU.
            target = "cuda" if compute_device.value in ("torch", "cuda") else "cpu"
            X_arr = xp.asarray(X, dtype=xp.float64, device=torch.device(target))
        else:
            X_arr = xp.asarray(X, dtype=xp.float64)
        if X_arr.ndim == 1:
            X_arr = X_arr.reshape(-1, 1)

        n = int(X_arr.shape[0])
        p = int(X_arr.shape[1])

        if n < 2:
            raise ValueError(
                f"Need at least 2 samples to estimate covariance, got {n}"
            )

        # Center
        if self.assume_centered:
            location = xp_zeros(p, xp.float64, xp, X_arr)
        else:
            location = xp.mean(X_arr, axis=0)
            X_arr = X_arr - location

        # Sample covariance (normalised by n)
        S = (X_arr.T @ X_arr) / float(n)

        # ---- Ledoit-Wolf shrinkage intensity ----
        # Efficient formula (LW2004):
        #   beta  = (1/n^2) * [sum_k ||X_k||_2^4 - n * ||S||_F^2]
        #   delta = ||S - mu*I||_F^2 = ||S||_F^2 - tr(S)^2 / p
        #   alpha = clip(beta / delta, 0, 1)
        tr_S = _to_float_scalar(xp.trace(S))
        mu = tr_S / p

        norm_sq = xp.sum(X_arr * X_arr, axis=1)  # ||X_k||^2 per observation k
        sum_norm_sq_sq = _to_float_scalar(xp.sum(norm_sq * norm_sq))
        frob_S_sq = _to_float_scalar(xp.sum(S * S))

        beta = (sum_norm_sq_sq - float(n) * frob_S_sq) / (float(n) * float(n))
        delta = frob_S_sq - tr_S * tr_S / float(p)

        if delta <= 0.0:
            # Degenerate case: S already equals the scaled-identity target.
            alpha = 1.0
        else:
            alpha = max(0.0, min(1.0, beta / delta))

        # Shrunk covariance: (1 - alpha) * S + alpha * mu * I
        shrunk_S = (1.0 - alpha) * S + alpha * mu * xp_eye(p, xp.float64, xp, S)

        # Precision of shrunk covariance
        precision = _stable_inv(shrunk_S, xp, backend_name)

        self.covariance_ = shrunk_S
        self.location_ = location
        self.precision_ = precision
        self.shrinkage_ = alpha
        self.n_samples_ = n
        self.n_features_ = p
        self._backend_name = backend_name
        self._fitted = True
        return self
136
+
137
+
138
class OAS(EmpiricalCovariance):
    """
    Oracle Approximating Shrinkage (OAS) covariance estimator with GPU support.

    Uses the analytical formula from Chen, Wiesel, Eldar & Hero (2010)
    to compute the optimal shrinkage intensity under a Gaussian assumption.

    Parameters
    ----------
    assume_centered : bool, default=False
        If True, data is assumed to be already centered.
    device : str or Device, default='auto'
        Computation device: ``'cpu'``, ``'cuda'``, ``'torch'``, or ``'auto'``.
    n_jobs : int or None, default=None
        Number of parallel jobs (reserved for future use).

    Attributes
    ----------
    covariance_ : ndarray of shape (n_features, n_features)
        Estimated shrunk covariance matrix.
    location_ : ndarray of shape (n_features,)
        Estimated mean (zero if *assume_centered* is True).
    precision_ : ndarray of shape (n_features, n_features)
        Inverse of the shrunk covariance as computed by ``_stable_inv``.
    shrinkage_ : float
        Shrinkage intensity in [0, 1].
    n_samples_ : int
        Number of training samples.
    n_features_ : int
        Number of features.
    """

    def fit(self, X, y=None):
        """Fit the OAS covariance model to *X*.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data. A 1-D input is treated as ``n_samples`` values of
            a single feature.
        y : ignored

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If fewer than two samples are supplied, or if the shrunk
            covariance cannot be inverted by ``_stable_inv``.
        """
        compute_device = self._get_compute_device()
        backend_name = _detect_backend(X, compute_device)
        xp = _get_xp(backend_name)

        if backend_name == "torch":
            import torch

            # Same rule as EmpiricalCovariance.fit: use CUDA only when the
            # resolved compute device asks for it, so CPU tensors keep
            # working on machines without a GPU.
            target = "cuda" if compute_device.value in ("torch", "cuda") else "cpu"
            X_arr = xp.asarray(X, dtype=xp.float64, device=torch.device(target))
        else:
            X_arr = xp.asarray(X, dtype=xp.float64)
        if X_arr.ndim == 1:
            X_arr = X_arr.reshape(-1, 1)

        n = int(X_arr.shape[0])
        p = int(X_arr.shape[1])

        if n < 2:
            raise ValueError(
                f"Need at least 2 samples to estimate covariance, got {n}"
            )

        # Center
        if self.assume_centered:
            location = xp_zeros(p, xp.float64, xp, X_arr)
        else:
            location = xp.mean(X_arr, axis=0)
            X_arr = X_arr - location

        # Sample covariance (normalised by n)
        S = (X_arr.T @ X_arr) / float(n)

        # ---- OAS shrinkage intensity ----
        # Follows sklearn's implementation of the OAS formula (Chen et al. 2010).
        # Note: sklearn omits the 2/p factor from Eq. 23 in the original paper
        # because it negligibly affects the estimator for large p.
        tr_S = _to_float_scalar(xp.trace(S))
        alpha_mean = _to_float_scalar(xp.mean(S * S))  # mean of squared elements
        mu = tr_S / float(p)
        mu_squared = mu * mu

        numerator = alpha_mean + mu_squared
        denominator = (float(n) + 1.0) * (alpha_mean - mu_squared / float(p))

        if denominator <= 0.0:
            # Degenerate case: shrink fully onto the scaled-identity target.
            alpha = 1.0
        else:
            alpha = max(0.0, min(1.0, numerator / denominator))

        # Shrunk covariance: (1 - alpha) * S + alpha * mu * I
        shrunk_S = (1.0 - alpha) * S + alpha * mu * xp_eye(p, xp.float64, xp, S)

        # Precision of shrunk covariance
        precision = _stable_inv(shrunk_S, xp, backend_name)

        self.covariance_ = shrunk_S
        self.location_ = location
        self.precision_ = precision
        self.shrinkage_ = alpha
        self.n_samples_ = n
        self.n_features_ = p
        self._backend_name = backend_name
        self._fitted = True
        return self
@@ -0,0 +1,31 @@
1
+ """Generic cross-validation framework.
2
+
3
+ Provides CVEstimatorBase, kfold_indices, hash_cv_data, batch_mse,
4
+ and other CV utilities. Used by linear_model, survival, and other modules.
5
+ """
6
+
7
+ from ._base import (
8
+ CVEstimatorBase,
9
+ CVCache,
10
+ kfold_indices,
11
+ folds_are_complete,
12
+ hash_cv_data,
13
+ validate_cv_sample_weight,
14
+ batch_mse,
15
+ detect_gpu_input,
16
+ INTERCEPT_CLIP_BOUND,
17
+ )
18
+ from ._engine import run_cv
19
+
20
+ __all__ = [
21
+ "CVEstimatorBase",
22
+ "CVCache",
23
+ "kfold_indices",
24
+ "folds_are_complete",
25
+ "hash_cv_data",
26
+ "validate_cv_sample_weight",
27
+ "batch_mse",
28
+ "detect_gpu_input",
29
+ "INTERCEPT_CLIP_BOUND",
30
+ "run_cv",
31
+ ]