statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,348 @@
1
+ """Shared utilities for kernel-based nonparametric estimators."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from typing import Any, Union
7
+
8
+ import numpy as np
9
+
10
+ from statgpu.backends import (
11
+ _get_torch_device_str,
12
+ _get_xp,
13
+ _resolve_backend,
14
+ _to_float_scalar,
15
+ _to_numpy,
16
+ _torch_dev,
17
+ xp_arange,
18
+ xp_asarray,
19
+ xp_astype,
20
+ xp_empty,
21
+ xp_eye,
22
+ xp_full,
23
+ xp_maximum,
24
+ xp_ones,
25
+ xp_zeros,
26
+ )
27
+
28
# Names exported by ``from <this module> import *``.
# The backend helpers listed here (_get_xp, _resolve_backend, _to_float_scalar,
# _to_numpy) are imported from ``statgpu.backends`` above and re-exported from
# this module for backward compatibility.
__all__ = [
    # backend selection
    "_auto_backend_from_device",
    # input coercion
    "_as_points_2d",
    "_as_samples_2d",
    # bandwidth rules
    "_bandwidth_factor",
    "_bandwidth_factor_1d_nrd",
    # weights
    "_effective_sample_size",
    # re-exported backend helper
    "_get_xp",
    # kernels and name normalization
    "_kernel_values_from_quad",
    "_normalize_kernel_name",
    "_normalize_regression_name",
    "_normalize_weights",
    # re-exported backend helper
    "_resolve_backend",
    # linear algebra
    "_stable_inv_and_det",
    # re-exported backend helpers
    "_to_float_scalar",
    "_to_numpy",
    # covariance
    "_weighted_covariance",
]
47
+
48
+
49
+ def _torch_device_from_data(data) -> str:
50
+ """Extract device string from a torch tensor, or return 'cpu' for others."""
51
+ try:
52
+ import torch
53
+ if isinstance(data, torch.Tensor):
54
+ return str(data.device)
55
+ except (ImportError, AttributeError):
56
+ pass
57
+ return "cpu"
58
+
59
+
60
+ def _auto_backend_from_device(device: str, prefer_torch: bool = False) -> str:
61
+ d = str(device).strip().lower()
62
+ if d in ("numpy", "cpu"):
63
+ return "numpy"
64
+ if d == "torch":
65
+ return "torch"
66
+ if d in ("cuda", "gpu"):
67
+ # Check if Torch is available and has CUDA
68
+ if prefer_torch:
69
+ try:
70
+ import torch
71
+ if torch.cuda.is_available():
72
+ return "torch"
73
+ except Exception:
74
+ pass
75
+ # Otherwise try CuPy
76
+ try:
77
+ import cupy as cp
78
+ _ = int(cp.cuda.runtime.getDeviceCount())
79
+ return "cupy"
80
+ except Exception:
81
+ # Fallback to Torch if CuPy unavailable
82
+ if not prefer_torch:
83
+ try:
84
+ import torch
85
+ if torch.cuda.is_available():
86
+ return "torch"
87
+ except Exception:
88
+ pass
89
+ raise RuntimeError(
90
+ f"No GPU backend (CuPy or Torch CUDA) is available for "
91
+ f"device='{device}'. Use device='auto' to fall back to CPU."
92
+ )
93
+ # Default: prefer CuPy, then Torch, then NumPy
94
+ try:
95
+ import cupy as cp
96
+ _ = int(cp.cuda.runtime.getDeviceCount())
97
+ return "cupy"
98
+ except Exception:
99
+ try:
100
+ import torch
101
+ if torch.cuda.is_available():
102
+ return "torch"
103
+ except Exception:
104
+ pass
105
+ return "numpy"
106
+
107
+
108
+ def _normalize_kernel_name(kernel: str) -> str:
109
+ name = str(kernel).strip().lower()
110
+ aliases = {
111
+ "gaussian": "gaussian",
112
+ "normal": "gaussian",
113
+ "rectangular": "rectangular",
114
+ "uniform": "rectangular",
115
+ "box": "rectangular",
116
+ "triangular": "triangular",
117
+ "epanechnikov": "epanechnikov",
118
+ "epa": "epanechnikov",
119
+ "biweight": "biweight",
120
+ "quartic": "biweight",
121
+ "triweight": "triweight",
122
+ "cosine": "cosine",
123
+ "optcosine": "optcosine",
124
+ }
125
+ normalized = aliases.get(name)
126
+ if normalized is None:
127
+ raise ValueError(
128
+ "kernel must be one of: 'gaussian', 'rectangular', 'triangular', "
129
+ "'epanechnikov', 'biweight', 'triweight', 'cosine', 'optcosine'"
130
+ )
131
+ return normalized
132
+
133
+
134
+ def _normalize_regression_name(regression: str) -> str:
135
+ name = str(regression).strip().lower()
136
+ aliases = {
137
+ "nw": "nw",
138
+ "nadaraya_watson": "nw",
139
+ "nadaraya-watson": "nw",
140
+ "local_linear": "local_linear",
141
+ "local-linear": "local_linear",
142
+ "ll": "local_linear",
143
+ }
144
+ normalized = aliases.get(name)
145
+ if normalized is None:
146
+ raise ValueError(
147
+ "regression must be one of: 'nw', 'nadaraya_watson', 'local_linear', 'll'"
148
+ )
149
+ return normalized
150
+
151
+
152
+ def _kernel_values_from_quad(quad, kernel_name: str, xp):
153
+ if kernel_name == "gaussian":
154
+ return xp.exp(-0.5 * quad)
155
+
156
+ support_mask = quad <= 1.0
157
+ if kernel_name == "rectangular":
158
+ return xp_astype(support_mask, xp.float64, xp)
159
+
160
+ if kernel_name == "triangular":
161
+ return xp_maximum(1.0 - xp.sqrt(xp_maximum(quad, 0.0, xp)), 0.0, xp)
162
+
163
+ one_minus_quad = xp_maximum(1.0 - quad, 0.0, xp)
164
+ if kernel_name == "epanechnikov":
165
+ return one_minus_quad
166
+ if kernel_name == "biweight":
167
+ return one_minus_quad * one_minus_quad
168
+ if kernel_name == "triweight":
169
+ return one_minus_quad * one_minus_quad * one_minus_quad
170
+ if kernel_name == "cosine":
171
+ r = xp.sqrt(xp_maximum(quad, 0.0, xp))
172
+ return xp.where(support_mask, 0.5 * (1.0 + xp.cos(math.pi * r)), 0.0)
173
+ if kernel_name == "optcosine":
174
+ r = xp.sqrt(xp_maximum(quad, 0.0, xp))
175
+ return xp.where(support_mask, xp.cos(0.5 * math.pi * r), 0.0)
176
+
177
+ raise ValueError(f"Unsupported kernel: {kernel_name}")
178
+
179
+
180
def _as_samples_2d(samples, xp, ref_arr=None):
    """Coerce *samples* to a float64 2-D array of shape ``(n_samples, n_features)``.

    A 1-D input is treated as ``n_samples`` observations of one feature and
    reshaped to a column. *ref_arr* is forwarded to ``xp_asarray``.

    Raises
    ------
    ValueError
        If the input is not 1-D or 2-D, or has fewer than two rows.
    """
    data = xp_asarray(samples, dtype=xp.float64, xp=xp, ref_arr=ref_arr)
    if data.ndim not in (1, 2):
        raise ValueError("samples must be 1D or 2D")
    if data.ndim == 1:
        data = data.reshape(-1, 1)

    # At least two observations are required.
    if int(data.shape[0]) < 2:
        raise ValueError("samples must contain at least 2 observations")
    return data
191
+
192
+
193
def _as_points_2d(points, n_features: int, xp, ref_arr=None):
    """Coerce evaluation *points* to a float64 2-D array with *n_features* columns.

    A 1-D input is interpreted as follows:

    * when ``n_features == 1``, as a column of points;
    * when its length equals ``n_features``, as a single point (one row);
    * otherwise it is rejected.

    Raises
    ------
    ValueError
        On a rank other than 1 or 2, an incompatible 1-D length, or a column
        count different from *n_features*.
    """
    pts = xp_asarray(points, dtype=xp.float64, xp=xp, ref_arr=ref_arr)
    if pts.ndim not in (1, 2):
        raise ValueError("points must be 1D or 2D")

    if pts.ndim == 1:
        if n_features == 1:
            target_shape = (-1, 1)
        elif int(pts.size) == n_features:
            target_shape = (1, n_features)
        else:
            raise ValueError("points shape is incompatible with sample dimensionality")
        pts = pts.reshape(*target_shape)

    if int(pts.shape[1]) != int(n_features):
        raise ValueError("points feature dimension does not match samples")
    return pts
208
+
209
+
210
+ def _normalize_weights(weights, n_samples: int, xp, device: str = "cpu", ref_arr=None):
211
+ if weights is None:
212
+ fill_val = 1.0 / float(n_samples)
213
+ return xp_full(n_samples, fill_val, xp.float64, xp, ref_arr=ref_arr)
214
+
215
+ w = xp_asarray(weights, dtype=xp.float64, xp=xp, ref_arr=ref_arr).reshape(-1)
216
+ if int(w.size) != int(n_samples):
217
+ raise ValueError("weights must have the same length as samples")
218
+ if _to_float_scalar(xp.min(w)) < 0.0:
219
+ raise ValueError("weights must be non-negative")
220
+
221
+ w_sum = xp.sum(w)
222
+ if _to_float_scalar(w_sum) <= 0.0:
223
+ raise ValueError("weights must sum to a positive value")
224
+
225
+ return w / w_sum
226
+
227
+
228
def _effective_sample_size(weights, xp) -> float:
    """Return ``1 / sum(weights**2)``.

    This is the usual (Kish) effective sample size only when *weights* already
    sum to one; normalization is assumed here, not checked.

    Raises
    ------
    ValueError
        If ``sum(weights**2)`` is NaN/infinite or not strictly positive.
    """
    w2 = xp.sum(weights * weights)
    denom = _to_float_scalar(w2)
    # NaN would pass a plain "<= 0" test, and +inf would yield an ESS of 0.
    # Both signal corrupt weights, so reject them explicitly.
    if not math.isfinite(denom):
        raise ValueError("invalid weights: effective sample size denominator is not finite")
    if denom <= 0.0:
        raise ValueError("invalid weights: effective sample size denominator is non-positive")
    return 1.0 / denom
234
+
235
+
236
def _bandwidth_factor_1d_nrd(
    method: str,
    *,
    n_eff: float,
    samples_2d,
    data_cov,
    xp,
) -> float:
    """Apply the ``"nrd0"`` / ``"nrd"`` rule of thumb and return ``bw_abs / data_sd``.

    Procedure:

    * The absolute bandwidth is
      ``c * min(sd, IQR / 1.34) * n_eff ** (-1/5)``, with ``c = 0.9`` for
      ``"nrd0"`` and ``c = 1.06`` for ``"nrd"``.
    * ``sd`` and ``IQR`` are unweighted and use only the finite entries of the
      first column of *samples_2d*.
    * If ``IQR / 1.34`` is not a finite positive number, ``sd`` alone is used.
      If that scale is still not a finite positive number, the population
      (ddof=0) standard deviation is tried.
    * ``data_sd`` is ``sqrt(data_cov[0, 0])``, falling back to
      ``max(tiny, sd)`` when that is not a finite positive number.

    *xp* is accepted for signature consistency and is not used.

    Raises
    ------
    ValueError
        For an unknown *method*, fewer than two finite samples, a degenerate
        scale, or a non-finite / non-positive result.
    """
    rule = str(method).strip().lower()
    rule_coefficients = {"nrd0": 0.9, "nrd": 1.06}
    if rule not in rule_coefficients:
        raise ValueError("method must be one of: 'nrd0', 'nrd'")

    column = np.asarray(_to_numpy(samples_2d[:, 0]), dtype=np.float64)
    column = column[np.isfinite(column)]
    if column.size < 2:
        raise ValueError("need at least 2 finite samples for 'nrd0'/'nrd' bandwidth")

    sd = float(np.std(column, ddof=1))
    upper_q, lower_q = np.quantile(column, [0.75, 0.25])
    iqr_scale = float((upper_q - lower_q) / 1.34)

    if np.isfinite(iqr_scale) and iqr_scale > 0.0:
        scale = min(sd, iqr_scale)
    else:
        scale = sd
    if not np.isfinite(scale) or scale <= 0.0:
        # Last resort: population standard deviation.
        scale = float(np.std(column, ddof=0))
        if not np.isfinite(scale) or scale <= 0.0:
            raise ValueError("unable to compute positive scale for 'nrd0'/'nrd' bandwidth")

    bw_abs = float(rule_coefficients[rule] * scale * (float(n_eff) ** (-1.0 / 5.0)))
    if not np.isfinite(bw_abs) or bw_abs <= 0.0:
        raise ValueError("automatic bandwidth rule produced a non-positive value")

    data_sd = math.sqrt(max(_to_float_scalar(data_cov[0, 0]), 0.0))
    if data_sd <= 0.0 or not np.isfinite(data_sd):
        data_sd = max(float(np.finfo(np.float64).tiny), sd)

    factor = float(bw_abs / data_sd)
    if not np.isfinite(factor) or factor <= 0.0:
        raise ValueError("bandwidth factor must be a finite positive scalar")
    return factor
276
+
277
+
278
+ def _bandwidth_factor(
279
+ bandwidth: Union[str, float, int],
280
+ *,
281
+ n_eff: float,
282
+ n_features: int,
283
+ ) -> float:
284
+ if isinstance(bandwidth, str):
285
+ method = bandwidth.strip().lower()
286
+ if method == "scott":
287
+ factor = n_eff ** (-1.0 / (n_features + 4.0))
288
+ elif method == "silverman":
289
+ factor = (n_eff * (n_features + 2.0) / 4.0) ** (-1.0 / (n_features + 4.0))
290
+ else:
291
+ raise ValueError(
292
+ "bandwidth must be one of: 'scott', 'silverman', 'nrd0', 'nrd', "
293
+ "'ucv', 'bcv', 'sj', 'sj-ste', 'sj-dpi', 'cv', 'cv_ls', 'cv-nw', 'cv-ll', "
294
+ "or a positive scalar"
295
+ )
296
+ else:
297
+ factor = float(bandwidth)
298
+
299
+ if not np.isfinite(factor) or factor <= 0.0:
300
+ raise ValueError("bandwidth factor must be a finite positive scalar")
301
+ return float(factor)
302
+
303
+
304
def _weighted_covariance(samples_2d, weights_1d, xp):
    """Weighted covariance of the rows of *samples_2d*, with a small diagonal ridge.

    Computes

        mu  = sum_i w_i x_i
        cov = sum_i w_i (x_i - mu)(x_i - mu)^T / (1 - sum_i w_i^2)

    This is the usual unbiased form when the weights sum to one (assumed, not
    checked here). The result is symmetrized, then
    ``max(1e-12 * trace / d, 1e-12)`` is added to the diagonal. When the trace
    is not finite, 1.0 replaces ``trace / d``.

    Raises
    ------
    ValueError
        If ``1 - sum(w**2) <= 1e-15``, i.e. too few effective degrees of freedom.
    """
    dim = int(samples_2d.shape[1])
    w = weights_1d

    mu = xp.sum(samples_2d * w[:, None], axis=0)
    resid = samples_2d - mu

    correction = 1.0 - xp.sum(w * w)
    if _to_float_scalar(correction) <= 1e-15:
        raise ValueError("effective degrees of freedom is too small for covariance estimation")

    weighted_resid_t = resid.T * w[None, :]
    cov = weighted_resid_t @ resid / correction
    # Remove round-off asymmetry.
    cov = 0.5 * (cov + cov.T)

    tr = _to_float_scalar(xp.trace(cov))
    mean_diag = tr / float(max(1, dim)) if np.isfinite(tr) else 1.0
    ridge = max(mean_diag * 1e-12, 1e-12)
    return cov + ridge * xp_eye(dim, xp.float64, xp, ref_arr=cov)
323
+
324
+
325
def _stable_inv_and_det(cov, xp):
    """Invert *cov*, adding growing diagonal jitter until the result looks usable.

    Up to 8 attempts are made. An attempt succeeds when ``xp.linalg.inv``
    raises nothing and ``det`` is finite and strictly positive. After each
    failed attempt, jitter is added to the diagonal (cumulatively). The jitter
    starts at ``max(1e-12 * trace / d, 1e-12)`` and grows tenfold each time.

    NOTE(review): ``det`` is evaluated directly, so for large or badly scaled
    matrices it can underflow to 0 or overflow even when the matrix is well
    conditioned. A positive determinant also does not by itself prove positive
    definiteness.

    Returns
    -------
    tuple
        ``(inv_cov, det_cov, cov_work)``:

        * ``cov_work`` is the float64 (possibly jittered) matrix that was
          inverted;
        * ``det_cov`` is a Python float.

    Raises
    ------
    ValueError
        ``"covariance inversion failed"`` (chained from the last recorded
        exception) if any attempt raised. Otherwise ``"covariance matrix is not
        positive definite"``.
    """
    dim = int(cov.shape[0])
    work = xp_astype(cov, xp.float64, xp)

    tr = _to_float_scalar(xp.trace(work))
    mean_diag = tr / float(max(1, dim)) if np.isfinite(tr) else 1.0
    eps = max(mean_diag * 1e-12, 1e-12)

    failure = None
    attempts_left = 8
    while attempts_left:
        attempts_left -= 1
        try:
            inverse = xp.linalg.inv(work)
            determinant = _to_float_scalar(xp.linalg.det(work))
            if np.isfinite(determinant) and determinant > 0.0:
                return inverse, determinant, work
        except Exception as exc:
            # Backends raise different exception types for singular input.
            failure = exc

        work = work + eps * xp_eye(dim, xp.float64, xp, ref_arr=work)
        eps *= 10.0

    if failure is not None:
        raise ValueError("covariance inversion failed") from failure
    raise ValueError("covariance matrix is not positive definite")