statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""Safe torch import wrapper for Torch 2.8+ compatibility.
|
|
2
|
+
|
|
3
|
+
Torch 2.8.0+ may raise RuntimeError('Only a single TORCH_LIBRARY can be
|
|
4
|
+
used to register the namespace prims') when imported in environments where
|
|
5
|
+
torch has already been loaded (e.g., Jupyter kernels, other processes).
|
|
6
|
+
|
|
7
|
+
This module provides a safe import that catches this error and marks torch
|
|
8
|
+
as unavailable. All torch imports in statgpu should go through this module
|
|
9
|
+
via: from statgpu.backends._torch_safe import get_torch
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
# Lazily populated import cache shared by get_torch() and torch_available().
# _torch holds the torch module after a successful import; _torch_available is
# tri-state: None = not probed yet, True/False = outcome of the first probe.
_torch_available = _torch = None
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_torch():
    """Return the torch module, or None if it cannot be imported.

    The outcome of the first import attempt is cached in the module globals
    ``_torch`` / ``_torch_available``. Later calls return the cached result
    without retrying, so a failed import stays failed for the rest of the
    process.

    Catches ``ImportError`` (torch not installed) and ``RuntimeError``
    (TORCH_LIBRARY registration conflicts seen on Torch 2.8+ in environments
    with pre-existing torch state).
    """
    global _torch, _torch_available

    if _torch_available is not None:
        # Already probed: _torch is the module on success, None on failure.
        return _torch

    try:
        import torch
    except (ImportError, RuntimeError):
        # RuntimeError: TORCH_LIBRARY conflict on Torch 2.8+
        # ImportError: torch not installed
        _torch = None
        _torch_available = False
        return None

    _torch = torch
    _torch_available = True
    return _torch
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def torch_available():
    """Return True if torch can be imported, False otherwise.

    Note that this *does* import torch on its first call (through
    ``get_torch()``). The result is cached, so later calls are cheap and
    never retry a failed import.
    """
    # Reading the module global needs no ``global`` statement; only
    # get_torch() assigns it.
    if _torch_available is None:
        get_torch()
    return _torch_available
|
|
@@ -0,0 +1,423 @@
|
|
|
1
|
+
"""General-purpose backend utility functions.
|
|
2
|
+
|
|
3
|
+
These helpers are used across statgpu submodules to avoid duplicating
|
|
4
|
+
array-library detection, module resolution, and scalar conversion logic.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
__all__ = ["xp_zeros", "xp_eye", "xp_full", "xp_astype", "xp_asarray", "xp_empty", "torch_compile_supported"]
|
|
10
|
+
|
|
11
|
+
from typing import Any, Optional
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
# Exception types raised by linalg operations on singular/ill-conditioned matrices.
# numpy raises LinAlgError; torch raises RuntimeError for linalg failures.
# NOTE: torch RuntimeError is overly broad (also catches OOM, autograd errors).
# Callers should re-raise if the error message doesn't match linalg patterns.
# NOTE: We do NOT import torch at module level to avoid TORCH_LIBRARY
# registration conflicts on Torch 2.8+. Torch is imported lazily by
# _ensure_torch(), which calls _safe_import_torch().
# The tuple starts numpy-only. _ensure_torch() *rebinds* this global to
# (np.linalg.LinAlgError, RuntimeError) once torch imports successfully, so
# code that did ``from ._utils import _LINALG_ERRORS`` keeps the value it saw
# at import time. Read it as a module attribute at use time to pick up the
# torch entry.
_LINALG_ERRORS: tuple = (np.linalg.LinAlgError,)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _safe_import_torch():
    """Import torch safely, catching TORCH_LIBRARY registration conflicts.

    Torch 2.8+ may raise RuntimeError when imported in environments where
    torch has already been loaded (Jupyter kernels, other processes).

    Delegates to ``statgpu.backends._torch_safe.get_torch`` so the package
    makes a single guarded import attempt and shares one cached result, as
    that module's documentation requires. If the sibling module cannot be
    imported (e.g. this file is loaded outside its package), falls back to a
    direct guarded import.

    Returns the torch module, or None if import fails.
    """
    try:
        from ._torch_safe import get_torch
    except ImportError:
        get_torch = None

    if get_torch is not None:
        return get_torch()

    try:
        import torch
        return torch
    except (ImportError, RuntimeError):
        return None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Lazy torch-import cache maintained by _ensure_torch().
_TORCH_AVAILABLE: Optional[bool] = None  # None until the first import attempt
_TORCH_MODULE = None  # the torch module after a successful import, else None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _ensure_torch():
    """Return the torch module, importing it on first use, or None.

    The first call performs the guarded import and records the outcome in
    ``_TORCH_MODULE`` / ``_TORCH_AVAILABLE``. Later calls return the cached
    module. After a successful import, ``_LINALG_ERRORS`` is widened to also
    include ``RuntimeError``.
    """
    global _TORCH_AVAILABLE, _TORCH_MODULE, _LINALG_ERRORS
    if _TORCH_AVAILABLE is not None:
        return _TORCH_MODULE

    _TORCH_MODULE = _safe_import_torch()
    _TORCH_AVAILABLE = _TORCH_MODULE is not None
    if _TORCH_AVAILABLE:
        _LINALG_ERRORS = (np.linalg.LinAlgError, RuntimeError)
    return _TORCH_MODULE
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _require_torch():
    """Return the torch module, raising ``ImportError`` if it is unavailable.

    Raises
    ------
    ImportError
        When torch is not installed or its import failed (for example a
        TORCH_LIBRARY registration conflict on Torch 2.8+).
    """
    torch_mod = _ensure_torch()
    if torch_mod is not None:
        return torch_mod
    raise ImportError(
        "Torch is not available. This may be due to a TORCH_LIBRARY "
        "registration conflict (Torch 2.8+) or missing installation."
    )
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _get_xp(backend_name: str):
    """Return the array module (numpy / cupy / torch) for *backend_name*.

    Parameters
    ----------
    backend_name : str
        One of ``'numpy'``, ``'cupy'``, or ``'torch'``.

    Returns
    -------
    module
        The array module (``numpy``, ``cupy``, or ``torch``).

    Raises
    ------
    ValueError
        If *backend_name* is not recognised.
    ImportError
        If the requested library is not installed or cannot be imported.
        For torch, the message includes the underlying reason, because
        torch may be installed yet fail to import (TORCH_LIBRARY conflict).
    """
    if backend_name == "numpy":
        return np

    if backend_name == "cupy":
        try:
            import cupy as cp
        except ImportError as exc:
            raise ImportError(
                "backend='cupy' requires CuPy, but CuPy is not installed"
            ) from exc
        return cp

    if backend_name == "torch":
        try:
            return _require_torch()
        except ImportError as exc:
            # _require_torch() also fails when torch *is* installed but its
            # import hit a registration conflict; keep that reason visible.
            raise ImportError(
                "backend='torch' requires PyTorch, but PyTorch is not installed "
                f"or could not be imported: {exc}"
            ) from exc

    raise ValueError(f"Unsupported backend: {backend_name}")
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _to_numpy(x):
    """Convert *x* to a ``numpy.ndarray``.

    Handles:

    * CuPy-style device arrays, via ``.get()``. Mapping-like objects (dict,
      pandas Series/DataFrame) also define ``.get`` but need a key, so
      anything exposing ``.keys()`` is excluded from this path.
    * PyTorch tensors, via ``.detach().cpu().numpy()``.
    * Anything else ``np.asarray`` accepts.
    """
    if hasattr(x, "get") and not hasattr(x, "keys"):
        return x.get()
    if hasattr(x, "cpu") and hasattr(x, "numpy"):
        return x.detach().cpu().numpy() if hasattr(x, "detach") else x.cpu().numpy()
    return np.asarray(x)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _to_float_scalar(x: Any) -> float:
    """Return *x* as a Python ``float``.

    Objects exposing ``.item()`` (numpy / cupy / torch scalars and 0-d
    arrays) are unwrapped first; everything else goes straight to ``float``.
    """
    value = x.item() if hasattr(x, "item") else x
    return float(value)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _get_torch_device_str() -> str:
    """Return ``'cuda'`` when torch reports a usable CUDA device, else ``'cpu'``.

    Falls back to ``'cpu'`` silently if torch is unavailable, and with a
    warning if the CUDA probe itself raises.
    """
    try:
        torch_mod = _require_torch()
        has_cuda = torch_mod.cuda.is_available()
    except ImportError:
        return "cpu"
    except Exception as err:
        import warnings
        warnings.warn(f"torch.cuda.is_available() failed, falling back to CPU: {err}")
        return "cpu"
    return "cuda" if has_cuda else "cpu"
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _torch_on_target_device(tensor, device: Optional[str]) -> bool:
    """Return True when a torch tensor is already on the requested device.

    Parameters
    ----------
    tensor : object
        Typically a ``torch.Tensor``. Objects without a ``device`` attribute
        are treated as being on no device (never a match unless *device* is
        None).
    device : str, torch.device, or None
        ``None`` means "any device" and always matches. A bare device type
        such as ``"cuda"`` or ``"cpu"`` matches any index of that type
        (``"cuda"`` matches ``"cuda:1"``). An indexed device such as
        ``"cuda:0"`` requires an exact match.
    """
    if device is None:
        return True
    device = str(device)
    tensor_device = str(getattr(tensor, "device", ""))
    if ":" not in device:
        # Compare only the device type, so "cuda" matches "cuda:0", "cuda:1", ...
        return tensor_device.split(":", 1)[0] == device
    return tensor_device == device
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _move_torch_tensor(tensor, device: Optional[str] = None, dtype=None, pin_memory: bool = False):
    """Move/cast a torch tensor, using pinned non-blocking H2D when useful.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to move and/or cast.
    device : str or None
        Target device. When ``None``, the target used for the early-exit check
        and the pinned path is ``_get_torch_device_str()`` ("cuda" if CUDA is
        available, else "cpu"). See the note below on when a move actually
        happens.
    dtype : torch.dtype, numpy dtype-like, or None
        Target dtype. Non-torch values are translated with
        ``getattr(torch, np.dtype(dtype).name)``. If that translation fails,
        the value is left unchanged (presumably ``.to()`` then rejects it;
        TODO confirm).
    pin_memory : bool
        When True and the copy is CPU -> CUDA, pin the source and copy with
        ``non_blocking=True``.

    Returns
    -------
    torch.Tensor
        ``tensor`` itself when no move/cast is needed, otherwise the result of
        ``Tensor.to``.

    Notes
    -----
    NOTE(review): with ``device=None`` the tensor is relocated only on the
    pinned path (``pin_memory=True`` and the default target is CUDA). On the
    plain path ``device`` is omitted from ``.to()``, so the tensor stays
    where it is. Confirm this asymmetry is intended.
    """
    torch = _require_torch()

    # Translate numpy-style dtypes (e.g. "float32", np.float64) to torch dtypes.
    if dtype is not None and not isinstance(dtype, torch.dtype):
        try:
            dtype = getattr(torch, np.dtype(dtype).name)
        except Exception:
            # Best effort: keep the caller's dtype object unchanged.
            pass

    # Default target (auto-detected when device is None); see NOTE above.
    target = device or _get_torch_device_str()
    needs_move = not _torch_on_target_device(tensor, target)
    needs_dtype = dtype is not None and tensor.dtype != dtype
    # Fast path: already on the target device with the requested dtype.
    if not needs_move and not needs_dtype:
        return tensor

    # Pinned, non-blocking host-to-device copy. On any failure, fall through
    # to the plain .to() below.
    if pin_memory and str(target).startswith("cuda") and tensor.device.type == "cpu":
        try:
            pinned = tensor.pin_memory() if not tensor.is_pinned() else tensor
            kwargs = {"device": target, "non_blocking": True}
            if dtype is not None:
                kwargs["dtype"] = dtype
            return pinned.to(**kwargs)
        except Exception:
            pass

    # Plain path: pass a device only when the caller asked for one explicitly.
    kwargs = {}
    if device is not None:
        kwargs["device"] = target
    if dtype is not None:
        kwargs["dtype"] = dtype
    return tensor.to(**kwargs) if kwargs else tensor
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _numpy_to_torch_tensor(x, device: Optional[str] = None, dtype=None, pin_memory: bool = False):
    """Convert NumPy-like input to torch, preserving contiguous fast paths.

    ``torch.from_numpy`` shares memory with the array it receives. Arrays that
    are already C-contiguous and writeable are therefore wrapped without a
    copy, and the result may alias the caller's array when no device/dtype
    change is requested. Non-contiguous or read-only arrays are copied first.
    Otherwise torch would either need a contiguous buffer or wrap a read-only
    buffer that in-place tensor ops could write into.

    The tensor is then moved/cast with ``_move_torch_tensor``.
    """
    torch = _require_torch()

    arr = np.asarray(x)
    if not (arr.flags["C_CONTIGUOUS"] and arr.flags["WRITEABLE"]):
        arr = np.array(arr, order="C", copy=True)
    tensor = torch.from_numpy(arr)
    return _move_torch_tensor(tensor, device=device, dtype=dtype, pin_memory=pin_memory)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _cupy_to_torch_dlpack(x, device: Optional[str] = None):
    """Wrap a CuPy array as a torch tensor via DLPack, then move it to *device*.

    Returns None when *x* is not a ``cupy.ndarray``, when CuPy or torch cannot
    be imported, or when any step of the conversion raises.
    """
    try:
        import cupy as cp
        torch = _require_torch()

        if not isinstance(x, cp.ndarray):
            return None
        from_dlpack = torch.utils.dlpack.from_dlpack
        try:
            converted = from_dlpack(x)
        except TypeError:
            # Retry through an explicit DLPack capsule.
            converted = from_dlpack(x.toDlpack())
        return _move_torch_tensor(converted, device=device)
    except Exception:
        return None
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _torch_to_cupy_dlpack(x):
    """Expose a CUDA torch tensor as a CuPy array via DLPack.

    The tensor is detached and made contiguous first. Returns None when *x*
    is not a CUDA ``torch.Tensor``, when CuPy or torch cannot be imported, or
    when both DLPack entry points fail.
    """
    try:
        import cupy as cp
        torch = _require_torch()

        is_cuda_tensor = isinstance(x, torch.Tensor) and x.is_cuda
        if not is_cuda_tensor:
            return None
        source = x.detach()
        source = source if source.is_contiguous() else source.contiguous()
        try:
            return cp.from_dlpack(source)
        except Exception:
            # Fall back to CuPy's legacy capsule-based entry point.
            return cp.fromDlpack(torch.utils.dlpack.to_dlpack(source))
    except Exception:
        return None
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
# ---------------------------------------------------------------------------
|
|
235
|
+
# Device-aware array creation helpers
|
|
236
|
+
# ---------------------------------------------------------------------------
|
|
237
|
+
|
|
238
|
+
def _torch_dev(arr):
    """Return ``arr.device`` for torch tensors; ``None`` for anything else.

    Also returns ``None`` when torch itself is unavailable.
    """
    try:
        torch_mod = _require_torch()
        return arr.device if isinstance(arr, torch_mod.Tensor) else None
    except (ImportError, AttributeError):
        return None
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def xp_zeros(shape, dtype, xp, ref_arr=None):
    """Zero-filled array from *xp*, placed on *ref_arr*'s device when it is a torch tensor."""
    device = None if ref_arr is None else _torch_dev(ref_arr)
    if device is None:
        return xp.zeros(shape, dtype=dtype)
    return xp.zeros(shape, dtype=dtype, device=device)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def xp_eye(n, dtype, xp, ref_arr=None):
    """Identity matrix from *xp*, placed on *ref_arr*'s device when it is a torch tensor."""
    device = None if ref_arr is None else _torch_dev(ref_arr)
    placement = {} if device is None else {"device": device}
    return xp.eye(n, dtype=dtype, **placement)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def xp_full(shape, fill_value, dtype, xp, ref_arr=None):
    """Constant-filled array from *xp*, device-matched to a torch *ref_arr*.

    An integer *shape* is promoted to a 1-tuple first. This normalisation is
    there for backends (such as torch) whose ``full`` expects a size sequence.
    """
    size = (shape,) if isinstance(shape, int) else shape
    device = None if ref_arr is None else _torch_dev(ref_arr)
    if device is None:
        return xp.full(size, fill_value, dtype=dtype)
    return xp.full(size, fill_value, dtype=dtype, device=device)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def _np_dtype_to_torch(dtype):
    """Convert a numpy dtype (or dtype-like) to the equivalent torch dtype.

    The lookup uses ``np.dtype(dtype).name``, so byte order is ignored
    (``'>f4'`` maps to ``torch.float32``). Dtypes with no mapping fall back to
    ``torch.float64`` with a warning.
    """
    torch = _require_torch()
    _MAP = {
        'float32': torch.float32,
        'float64': torch.float64,
        'float16': torch.float16,
        'complex64': torch.complex64,
        'complex128': torch.complex128,
        'int32': torch.int32,
        'int64': torch.int64,
        'int16': torch.int16,
        'int8': torch.int8,
        'uint8': torch.uint8,
        'bool': torch.bool,
    }
    # ``.name`` drops byte-order prefixes, which str(np.dtype(...)) keeps ('>f4').
    key = np.dtype(dtype).name
    result = _MAP.get(key)
    if result is None:
        import warnings
        warnings.warn(f"Unknown numpy dtype '{dtype}' for torch conversion, falling back to float64", stacklevel=2)
        return torch.float64
    return result
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def _torch_dtype_to_np(dtype):
    """Convert a torch dtype to the equivalent numpy dtype.

    Torch dtypes with no numpy counterpart in the table (for example
    ``torch.bfloat16``) fall back to ``float64``. Like ``_np_dtype_to_torch``,
    a warning is emitted instead of converting silently.
    """
    torch = _require_torch()
    _MAP = {
        torch.float32: np.dtype('float32'),
        torch.float64: np.dtype('float64'),
        torch.float16: np.dtype('float16'),
        torch.complex64: np.dtype('complex64'),
        torch.complex128: np.dtype('complex128'),
        torch.int32: np.dtype('int32'),
        torch.int64: np.dtype('int64'),
        torch.int16: np.dtype('int16'),
        torch.int8: np.dtype('int8'),
        torch.uint8: np.dtype('uint8'),
        torch.bool: np.dtype('bool'),
    }
    result = _MAP.get(dtype)
    if result is None:
        import warnings
        warnings.warn(
            f"Unknown torch dtype '{dtype}' for numpy conversion, falling back to float64",
            stacklevel=2,
        )
        return np.dtype('float64')
    return result
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def xp_astype(arr, dtype, xp=None):
    """Cast *arr* to *dtype* using the idiom of its own backend.

    Torch tensors use ``.to()``, and a non-torch *dtype* is first translated
    by ``_np_dtype_to_torch``. Everything else uses ``.astype()``.

    The ``xp`` argument is ignored (the backend is detected from ``arr``);
    it is accepted only so existing callers keep working.
    """
    if _torch_dev(arr) is None:
        return arr.astype(dtype)
    torch = _require_torch()
    target = dtype if isinstance(dtype, torch.dtype) else _np_dtype_to_torch(dtype)
    return arr.to(target)
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def xp_asarray(data, dtype=None, xp=None, ref_arr=None):
    """Device-aware ``xp.asarray``. *ref_arr* provides the target device.

    Parameters
    ----------
    data : array_like
        Input passed to ``xp.asarray``.
    dtype : optional
        Passed through unchanged, so it must be valid for *xp*.
    xp : module, optional
        Array namespace. When ``None`` it is inferred: torch if *ref_arr* is a
        torch tensor, otherwise numpy.
    ref_arr : optional
        If a torch tensor, the new array is created on its device.
    """
    dev = _torch_dev(ref_arr) if ref_arr is not None else None
    if xp is None:
        # Infer the namespace so the ``xp=None`` default is actually usable.
        xp = _require_torch() if dev is not None else np
    kwargs = {}
    if dev is not None:
        kwargs['device'] = dev
    if dtype is not None:
        kwargs['dtype'] = dtype
    return xp.asarray(data, **kwargs)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def xp_empty(shape, dtype, xp, ref_arr=None):
    """Uninitialised array from *xp*, placed on *ref_arr*'s device when it is a torch tensor."""
    device = None if ref_arr is None else _torch_dev(ref_arr)
    placement = {} if device is None else {"device": device}
    return xp.empty(shape, dtype=dtype, **placement)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def xp_arange(n, dtype=None, xp=None, ref_arr=None):
    """Device-aware ``xp.arange``. *ref_arr* provides the target device.

    Parameters
    ----------
    n : int
        Stop value passed to ``xp.arange``.
    dtype : optional
        Passed through unchanged, so it must be valid for *xp*.
    xp : module, optional
        Array namespace. When ``None`` it is inferred: torch if *ref_arr* is a
        torch tensor, otherwise numpy.
    ref_arr : optional
        If a torch tensor, the range is created on its device.
    """
    dev = _torch_dev(ref_arr) if ref_arr is not None else None
    if xp is None:
        # Infer the namespace so the ``xp=None`` default is actually usable.
        xp = _require_torch() if dev is not None else np
    kwargs = {}
    if dev is not None:
        kwargs['device'] = dev
    if dtype is not None:
        kwargs['dtype'] = dtype
    return xp.arange(n, **kwargs)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def xp_ones(shape, dtype, xp, ref_arr=None):
    """One-filled array from *xp*, placed on *ref_arr*'s device when it is a torch tensor."""
    device = None if ref_arr is None else _torch_dev(ref_arr)
    if device is None:
        return xp.ones(shape, dtype=dtype)
    return xp.ones(shape, dtype=dtype, device=device)
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def xp_maximum(arr, value, xp=None):
    """Element-wise maximum of *arr* and *value* across numpy/cupy/torch.

    ``torch.maximum`` needs two tensors, so a non-tensor *value* is first
    materialised with ``arr``'s dtype and device. Note that this casts
    *value* to ``arr.dtype``. For non-torch arrays, ``xp.maximum`` is used,
    or ``np.maximum`` when *xp* is None.
    """
    if _torch_dev(arr) is None:
        backend = np if xp is None else xp
        return backend.maximum(arr, value)
    torch = _require_torch()
    if isinstance(value, torch.Tensor):
        other = value
    else:
        other = torch.tensor(value, dtype=arr.dtype, device=arr.device)
    return torch.maximum(arr, other)
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def xp_copy(arr):
    """Return a copy of *arr*: ``.clone()`` for torch tensors, ``.copy()`` otherwise."""
    is_torch = _torch_dev(arr) is not None
    return arr.clone() if is_torch else arr.copy()
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def xp_cholesky_solve(A, b, xp):
    """Solve ``A @ x = b`` via Cholesky decomposition.

    *A* is assumed to be Hermitian (symmetric) positive definite. With
    ``A = L @ L^H``, the solution is obtained from two triangular solves,
    ``L y = b`` and then ``L^H x = y``.

    Backend handling:

    * CuPy (detected via ``.get``): general ``xp.linalg.solve``.
    * torch: ``torch.linalg.solve_triangular`` requires a right-hand side of
      shape ``(*, n, k)``, so a 1-D *b* is temporarily given a trailing unit
      dimension and squeezed back afterwards. ``L.mH`` (conjugate transpose
      of the last two dims) is used so complex and batched inputs are correct.
    * numpy: ``scipy.linalg.solve_triangular``, using ``trans='C'`` for the
      ``L^H`` solve.
    """
    if hasattr(A, 'get'):  # CuPy: no solve_triangular, use general solve directly
        return xp.linalg.solve(A, b)
    L = xp.linalg.cholesky(A)
    if _torch_dev(L) is not None:
        vector_rhs = b.ndim == 1
        rhs = b.unsqueeze(-1) if vector_rhs else b
        tmp = xp.linalg.solve_triangular(L, rhs, upper=False)
        x = xp.linalg.solve_triangular(L.mH, tmp, upper=True)
        return x.squeeze(-1) if vector_rhs else x
    # numpy: use scipy for solve_triangular
    from scipy.linalg import solve_triangular
    tmp = solve_triangular(L, b, lower=True)
    # trans='C' solves L^H x = tmp (identical to L.T for real inputs).
    return solve_triangular(L, tmp, lower=True, trans='C')
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def torch_compile_supported():
    """Report whether this environment is considered safe for ``torch.compile``.

    Returns True only when torch imports, CUDA is available, and the current
    device's compute-capability major version is at least 7. In every other
    case, including CPU-only setups and any error while probing, returns False.
    """
    try:
        torch_mod = _require_torch()
    except ImportError:
        return False  # torch not installed
    try:
        if torch_mod.cuda.is_available():
            major, _minor = torch_mod.cuda.get_device_capability()
            return major >= 7
    except Exception:
        pass
    return False  # Can't verify — assume not supported
|
statgpu/core/formula/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""
|
|
2
|
+
statgpu.core.formula – R-style formula interface for statgpu models.
|
|
3
|
+
|
|
4
|
+
This module provides formula-based model fitting similar to statsmodels/patsy::
|
|
5
|
+
|
|
6
|
+
>>> import statgpu as sg
|
|
7
|
+
>>> model = sg.LinearRegression()
|
|
8
|
+
>>> model.fit(formula="y ~ x1 + x2 + C(cat)", data=df)
|
|
9
|
+
>>> model.summary()
|
|
10
|
+
|
|
11
|
+
The formula syntax is parsed by `patsy` (optional dependency). Install with::
|
|
12
|
+
|
|
13
|
+
pip install statgpu[formula]
|
|
14
|
+
|
|
15
|
+
Public API
|
|
16
|
+
----------
|
|
17
|
+
FormulaParser
|
|
18
|
+
Main class for parsing R-style formulas and building design matrices.
|
|
19
|
+
parse_formula
|
|
20
|
+
Convenience function for one-shot formula evaluation.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from ._parser import FormulaParser
|
|
24
|
+
from ._design import parse_formula, parse_formula_safe
|
|
25
|
+
from ._terms import make_surv_env, _surv
|
|
26
|
+
|
|
27
|
+
__all__ = [
|
|
28
|
+
"FormulaParser",
|
|
29
|
+
"parse_formula",
|
|
30
|
+
"parse_formula_safe",
|
|
31
|
+
"make_surv_env",
|
|
32
|
+
"_surv",
|
|
33
|
+
]
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Design matrix building utilities.
|
|
3
|
+
|
|
4
|
+
Provides convenience function for one-shot formula evaluation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Tuple, Optional, Any
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
|
|
12
|
+
from ._parser import FormulaParser
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def parse_formula(
    formula: str,
    data: pd.DataFrame,
) -> Tuple[np.ndarray, np.ndarray, Any]:
    """Evaluate an R-style *formula* against *data* in a single call.

    This is shorthand for ``FormulaParser(formula).eval(data)``.

    Parameters
    ----------
    formula : str
        R-style formula string, e.g. ``"y ~ x1 + x2"``.
    data : pd.DataFrame
        DataFrame holding every column the formula references.

    Returns
    -------
    y : ndarray
        Response variable(s).
    X : ndarray
        Predictor design matrix.
    design_info : patsy.DesignInfo
        Metadata describing the predictor design.

    Examples
    --------
    >>> import pandas as pd
    >>> df = pd.DataFrame({"y": [1, 2, 3], "x": [4, 5, 6]})
    >>> y, X, info = parse_formula("y ~ x", df)
    """
    return FormulaParser(formula).eval(data)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def parse_formula_safe(
    formula: Optional[str],
    data: Optional[pd.DataFrame],
    X: Optional[np.ndarray] = None,
    y: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[Any]]:
    """Return ``(y, X, design_info)`` from either a formula or raw arrays.

    Model ``fit()`` methods use this so they accept both interfaces. When
    *formula* is given, *X* and *y* are ignored and the formula is evaluated
    against *data*. Otherwise *X* and *y* are converted with ``np.asarray``,
    and a column-vector *y* of shape ``(n, 1)`` is flattened to ``(n,)``.

    Parameters
    ----------
    formula : str or None
        R-style formula string. If ``None``, ``X`` and ``y`` are used directly.
    data : pd.DataFrame or None
        DataFrame for formula parsing. Required when ``formula`` is given.
    X : ndarray or None
        Raw predictor matrix (used when ``formula`` is ``None``).
    y : ndarray or None
        Raw response vector (used when ``formula`` is ``None``).

    Returns
    -------
    y : ndarray
        Response variable(s).
    X : ndarray
        Predictor design matrix.
    design_info : patsy.DesignInfo or None
        Design metadata (``None`` when raw arrays are used).

    Raises
    ------
    ValueError
        If *formula* is given without *data*, or if *formula* is ``None`` and
        either *X* or *y* is missing.
    """
    if formula is None:
        if X is None or y is None:
            raise ValueError(
                "Either formula+data or X+y must be provided. "
                "Got formula=None and incomplete array input."
            )
        response = np.asarray(y)
        if response.ndim == 2 and response.shape[1] == 1:
            response = response.ravel()
        return response, np.asarray(X), None

    if data is None:
        raise ValueError(
            "formula was provided but data (DataFrame) is None. "
            "When using formula, pass data=your_dataframe."
        )
    return parse_formula(formula, data)
|