statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""Breslow gradient and Hessian computation via PyTorch GPU operations (cuBLAS + vectorized).

This was originally attempted as a Triton serial-scan kernel. However, Triton 2.0 has a
compiler bug that produces non-deterministic, incorrect code for kernels with
runtime-bounded loops (while/for with >= 3 iterations). The PyTorch approach is
only marginally slower, because each GPU operation (matmul, outer product) is highly
optimized by cuBLAS.
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Any, Optional, Tuple
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
# Padded feature dimensions accepted by this module: the powers of two
# 2**3 through 2**7, i.e. (8, 16, 32, 64, 128), in ascending order.
_SUPPORTED_P: Tuple[int, ...] = tuple(1 << shift for shift in range(3, 8))
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _find_p_ce(p: int) -> Optional[int]:
    """Return the first entry of ``_SUPPORTED_P`` that is at least ``p``.

    Because ``_SUPPORTED_P`` is listed in ascending order, this is the smallest
    supported padded feature dimension able to hold ``p`` features.
    ``None`` signals that ``p`` exceeds every supported size.
    """
    return next((size for size in _SUPPORTED_P if size >= p), None)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def compute_breslow_grad_hess_triton(
    X: Any,
    beta: Any,
    time: Any,
    event: Any,
) -> Optional[Tuple[Any, Any]]:
    """Compute the Breslow partial-likelihood gradient and Hessian on a CUDA device.

    Despite the name (kept for backward compatibility), no Triton kernel is
    used. The gradient is fully vectorized, and the Hessian is accumulated by
    a Python loop over unique failure times that issues PyTorch GPU operations.

    Precondition: rows must already be sorted by ``time`` in ascending order.
    The risk set of row ``i`` is taken to be rows ``i..n-1`` (reverse
    cumulative sums), and the Hessian loop drops contiguous row ranges from a
    running risk-set matrix.

    Parameters
    ----------
    X : torch.Tensor
        Design matrix of shape ``(n, p)`` on a CUDA device.
    beta : torch.Tensor
        Coefficient vector of shape ``(p,)`` on a CUDA device.
    time : tensor or array-like
        Observed times, length ``n``. Moved to ``X.device`` if needed.
    event : tensor or array-like
        Event indicators, length ``n``. Entries equal to 1 are failures.
        Moved to ``X.device`` if needed.

    Returns
    -------
    tuple of (torch.Tensor, torch.Tensor) or None
        ``(grad, hess)`` with shapes ``(p,)`` and ``(p, p)``. ``grad`` is the
        sum over unique failure times of (sum of event rows) minus
        (number of events) * (risk-weighted mean of X). ``hess`` is
        ``sum_g d_g * (E_g E_g^T - S2_g / S0_g)``. Both are float64, and both
        are zeros when there are no events.
        ``None`` when ``X`` or ``beta`` is not a CUDA ``torch.Tensor``, or when
        ``p`` exceeds the largest supported padded dimension (so the caller
        can fall back to another implementation).
    """
    if not isinstance(X, torch.Tensor) or not isinstance(beta, torch.Tensor):
        return None
    if not X.is_cuda or not beta.is_cuda:
        return None

    p = int(X.shape[1])
    # Size cap retained from the original Triton design; callers rely on the
    # ``None`` return to fall back when p is too large.
    if _find_p_ce(p) is None:
        return None

    device = X.device
    # Accept CPU tensors / numpy arrays for time and event. This is a no-op
    # when they are already tensors on X's device.
    time = torch.as_tensor(time, device=device)
    event = torch.as_tensor(event, device=device)

    # Linear predictor and per-row risk weights.
    eta = X @ beta
    exp_eta = torch.exp(eta)
    X_exp = X * exp_eta[:, None]

    event_mask = event == 1
    if not torch.any(event_mask):
        return (
            torch.zeros(p, dtype=torch.float64, device=device),
            torch.zeros((p, p), dtype=torch.float64, device=device),
        )

    # Reverse cumulative sums: entry i sums rows i..n-1 (the risk set of row i
    # when rows are sorted by ascending time).
    risk_sum = torch.flip(torch.cumsum(torch.flip(exp_eta, dims=(0,)), dim=0), dims=(0,))
    risk_X_sum = torch.flip(torch.cumsum(torch.flip(X_exp, dims=(0,)), dim=0), dims=(0,))

    # Unique failure times, and how many events occur at each.
    event_times = time[event_mask]
    uft, unique_inv = torch.unique(event_times, sorted=True, return_inverse=True)
    n_uft = int(uft.shape[0])
    counts = torch.bincount(unique_inv).to(torch.float64)

    # First row index whose time equals each unique failure time. A stable
    # sort is required: with tied times an unstable sort may permute the tied
    # rows, so sort_idx[first_in_sorted] could land on a later tied row and
    # silently drop earlier tied rows from the risk set.
    sorted_times, sort_idx = torch.sort(time, stable=True)
    first_in_sorted = torch.searchsorted(sorted_times, uft, side="left")
    first_idx = sort_idx[first_in_sorted]

    # Risk-set totals and risk-weighted mean of X at each unique failure time.
    risk_at_uft = risk_sum[first_idx]
    E_X_at_uft = risk_X_sum[first_idx] / risk_at_uft[:, None]

    # Sum of X over the event rows at each unique failure time. The dtype
    # must match X for index_add_.
    sum_X_per_uft = torch.zeros((n_uft, p), dtype=X.dtype, device=device)
    sum_X_per_uft.index_add_(0, unique_inv, X[event_mask])

    # Gradient: Breslow closed form.
    grad = torch.sum(sum_X_per_uft - counts[:, None] * E_X_at_uft, dim=0)

    # Hessian: keep a running risk-set second-moment matrix and subtract rows
    # that leave the risk set as time advances.
    risk_X2 = X_exp.T @ X
    hess = torch.zeros((p, p), dtype=torch.float64, device=device)
    # A single device->host transfer instead of one .item() sync per iteration.
    start_rows = first_idx.tolist()
    prev_row = 0
    for g, row in enumerate(start_rows):
        if row > prev_row:
            risk_X2 -= X_exp[prev_row:row].T @ X[prev_row:row]
            prev_row = row
        w = counts[g]
        ex = E_X_at_uft[g]
        hess -= risk_X2 * (w / risk_at_uft[g])
        hess += torch.outer(ex, ex) * w

    return grad, hess
|