statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,685 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PyTorch GPU/CPU backend.
|
|
3
|
+
|
|
4
|
+
PyTorch tensors do *not* mirror the NumPy array API 1:1 (e.g. ``torch.linalg``
|
|
5
|
+
vs ``numpy.linalg``, different dtypes, etc.). The ``xp`` property therefore
|
|
6
|
+
returns the ``torch`` module itself; callers that need NumPy-compatible ops
|
|
7
|
+
should use the helper methods on this class instead of ``xp.<op>`` directly.
|
|
8
|
+
|
|
9
|
+
Note on API compatibility
|
|
10
|
+
-------------------------
|
|
11
|
+
Model code should use ``backend.xp`` for basic operations like:
|
|
12
|
+
- ``backend.xp.sum``, ``backend.xp.matmul``, ``backend.xp.sqrt``, etc.
|
|
13
|
+
- ``backend.xp.linalg.solve``, ``backend.xp.linalg.cholesky``, etc.
|
|
14
|
+
|
|
15
|
+
For operations with API differences (e.g. ``axis`` vs ``dim``), use helper
|
|
16
|
+
methods on this backend class.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
from statgpu.backends._base import BackendBase
|
|
22
|
+
from statgpu.backends._utils import (
|
|
23
|
+
_cupy_to_torch_dlpack,
|
|
24
|
+
_move_torch_tensor,
|
|
25
|
+
_numpy_to_torch_tensor,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Device string TorchBackend targets when the caller does not pass one.
# A bare "cuda" (no index) refers to whichever CUDA device is current.
_DEFAULT_TORCH_DEVICE = "cuda"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TorchBackend(BackendBase):
    """
    Array backend built on PyTorch, running on a CUDA device or on the CPU.

    The optional dependency is installed with ``pip install statgpu[torch]``.

    Parameters
    ----------
    device : str, default='cuda'
        Torch device string such as ``'cuda'``, ``'cuda:0'`` or ``'cpu'``.
        Tensors created by this backend are placed on this device.

    Examples
    --------
    >>> from statgpu.backends import TorchBackend
    >>> backend = TorchBackend(device='cuda')
    >>> xp = backend.xp  # the torch module itself
    >>> arr = backend.asarray([1, 2, 3])
    >>> backend.to_numpy(arr)
    array([1, 2, 3])
    """

    # Short identifier of this backend.
    name = "torch"
|
|
54
|
+
|
|
55
|
+
def __init__(self, device: str = _DEFAULT_TORCH_DEVICE):
    """
    Remember the target *device*.

    No CUDA work is done here; the one-time warm-up is deferred until
    ``_ensure_initialized`` is first called.
    """
    self._initialized = False
    self._device = device
|
|
58
|
+
|
|
59
|
+
def _ensure_initialized(self):
    """
    Run a one-time CUDA warm-up so later kernels avoid lazy-init latency.

    Only CUDA device strings (``'cuda'``, ``'cuda:N'``) trigger the warm-up.
    This matches the ``startswith("cuda")`` checks used elsewhere in this
    backend (``is_available``, the ``pin_memory`` choice in ``asarray``).
    Every other device is simply marked as initialised.
    """
    if self._initialized:
        return
    if self._device.startswith("cuda"):
        import torch
        if torch.cuda.is_available():
            # Small matmul to trigger CUDA kernel initialization.
            _ = torch.randn(32, 32, device=self._device) @ torch.randn(
                32, 32, device=self._device
            )
            # Wait on the *configured* device; a bare synchronize() would
            # only wait on the current device (e.g. cuda:0 for 'cuda:1').
            torch.cuda.synchronize(self._device)
    self._initialized = True

@property
def xp(self):
    """The ``torch`` module itself (imported lazily)."""
    import torch  # deferred import
    return torch
|
|
75
|
+
|
|
76
|
+
def asarray(self, x, dtype=None):
    """
    Convert x to a Torch tensor on the configured device.

    Parameters
    ----------
    x : array-like
        Input data (list, numpy.ndarray, cupy.ndarray, or torch.Tensor).
    dtype : torch.dtype, optional
        Desired data type (e.g., torch.float64). Honoured for every kind
        of input.

    Returns
    -------
    torch.Tensor
        A contiguous tensor on ``self._device``.
    """
    import torch
    self._ensure_initialized()  # Warmup on first use to avoid lazy CUDA init
    pin = self._device.startswith("cuda")
    if isinstance(x, torch.Tensor):
        t = _move_torch_tensor(x, device=self._device, dtype=dtype)
    elif hasattr(x, "get"):
        # CuPy arrays expose a .get() method that transfers the array from
        # GPU memory to a NumPy ndarray on the host. Duck-typing avoids a
        # mandatory cupy import here. Try the DLPack hand-off first and
        # fall back to a host round-trip when it returns None.
        t = _cupy_to_torch_dlpack(x, device=self._device)
        if t is None:
            t = _numpy_to_torch_tensor(
                x.get(),
                device=self._device,
                dtype=dtype,
                pin_memory=pin,
            )
        elif dtype is not None and t.dtype != dtype:
            # The DLPack helper takes no dtype, so apply it here.
            t = t.to(dtype=dtype)
    else:
        # Lists, NumPy arrays and other array-likes. ``dtype`` is forwarded
        # to the helper. (Previously a separate ``elif dtype is not None``
        # branch read an unassigned ``t`` and raised UnboundLocalError.)
        t = _numpy_to_torch_tensor(
            x,
            device=self._device,
            dtype=dtype,
            pin_memory=pin,
        )
    # Ensure result is contiguous for optimal performance
    if not t.is_contiguous():
        t = t.contiguous()
    return t
|
|
121
|
+
|
|
122
|
+
def to_numpy(self, x) -> np.ndarray:
    """
    Copy *x* into a host-side NumPy array.

    Parameters
    ----------
    x : torch.Tensor or array-like
        Normally a tensor produced by this backend. CuPy arrays and other
        array-likes are accepted too.

    Returns
    -------
    numpy.ndarray
    """
    import torch

    if not isinstance(x, torch.Tensor):
        # CuPy arrays provide .get() for a device -> host copy; anything
        # else goes through np.asarray.
        return x.get() if hasattr(x, "get") else np.asarray(x)
    # Drop autograd history and bring the data to host memory first.
    return x.detach().cpu().numpy()

def is_available(self) -> bool:
    """
    Report whether this backend can run here.

    Requires an importable ``torch``. A CUDA device string additionally
    requires ``torch.cuda.is_available()``; any other device is accepted.
    Any exception raised along the way is reported as "not available".
    """
    try:
        import torch
        wants_cuda = self._device.startswith("cuda")
        return torch.cuda.is_available() if wants_cuda else True
    except Exception:
        return False
|
|
155
|
+
|
|
156
|
+
# ------------------------------------------------------------------
# Linear-algebra helpers backed by torch.linalg
# ------------------------------------------------------------------

def solve(self, A, b):
    """Return ``x`` satisfying ``A @ x = b`` (``torch.linalg.solve``)."""
    import torch
    linalg = torch.linalg
    return linalg.solve(A, b)
|
|
164
|
+
|
|
165
|
+
def lstsq(self, A, b, rcond=None):
    """
    Return the least-squares solution to Ax ≈ b.

    torch.linalg.lstsq returns a named tuple; we unpack it for
    compatibility with numpy's lstsq interface.

    Parameters
    ----------
    A : torch.Tensor
        Coefficient matrix (m, n).
    b : torch.Tensor
        Right-hand side (m,) or (m, k).
    rcond : float, optional
        Relative cut-off on singular values used to decide the effective
        rank. Forwarded to ``torch.linalg.lstsq``. ``None`` keeps torch's
        default (machine precision of the dtype times ``max(m, n)``).

    Returns
    -------
    solution : torch.Tensor
    residuals : torch.Tensor
    rank : torch.Tensor
    singular_values : torch.Tensor

    Notes
    -----
    Unlike NumPy, which of ``residuals``, ``rank`` and ``singular_values``
    are filled in depends on the LAPACK driver torch picks; unfilled ones
    are empty tensors. Drivers without rank detection (e.g. CUDA's
    ``'gels'``) ignore ``rcond``.
    """
    import torch
    # rcond used to be accepted but silently dropped.
    result = torch.linalg.lstsq(A, b, rcond=rcond)
    return result.solution, result.residuals, result.rank, result.singular_values
|
|
182
|
+
|
|
183
|
+
def solve_triangular(self, A, b, lower=False, trans=False, unit_triangular=False):
    """
    Solve the triangular system ``op(A) x = b``.

    Parameters
    ----------
    A : torch.Tensor
        Triangular matrix (n, n).
    b : torch.Tensor
        Right-hand side (n,) or (n, k).
    lower : bool, default=False
        Whether to use the lower triangle of A.
    trans : bool, int or {'N', 'T', 'C'}, default=False
        Which system to solve, following ``scipy.linalg.solve_triangular``:
        ``False``/``0``/``'N'`` solves ``A x = b``; ``True``/``1``/``'T'``
        solves ``A^T x = b``; ``2``/``'C'`` solves ``A^H x = b``.
    unit_triangular : bool, default=False
        Whether to assume the diagonal of A is all ones.

    Returns
    -------
    x : torch.Tensor
        Solution to the system, with the same shape as ``b``.
    """
    import torch
    if isinstance(trans, str):
        mode = trans.upper()
        transpose = mode in ("T", "C")
        conjugate = mode == "C"
    else:
        transpose = bool(trans)
        conjugate = trans == 2
    if transpose:
        A = A.transpose(-2, -1)
        if conjugate:
            # Conjugate transpose; a no-op for real tensors.
            A = A.conj()
        # The transposed matrix has its entries in the opposite triangle.
        lower = not lower
    # torch.linalg.solve_triangular requires B to have at least 2 dims, so
    # a vector right-hand side is solved as a single column.
    vector_rhs = b.ndim == 1
    if vector_rhs:
        b = b.unsqueeze(-1)
    x = torch.linalg.solve_triangular(
        A,
        b,
        upper=not lower,
        unitriangular=bool(unit_triangular),
    )
    return x.squeeze(-1) if vector_rhs else x
|
|
219
|
+
|
|
220
|
+
# ------------------------------------------------------------------
# Additional Torch-native helpers for common operations
# ------------------------------------------------------------------

def sum(self, x, axis=None, keepdims=False):
    """
    Sum over the given axis/axes with NumPy semantics.

    Parameters
    ----------
    x : torch.Tensor
    axis : None, int or iterable of int
        ``None`` sums every element. An empty iterable performs no
        reduction, as in NumPy. Negative axes are resolved against the
        rank of *x*.
    keepdims : bool, default=False
        Keep reduced axes as size-1 dimensions (also for ``axis=None``).

    Note: Torch uses 'dim'/'keepdim' instead of 'axis'/'keepdims'.
    """
    import operator
    import torch
    if axis is None:
        total = torch.sum(x)
        if keepdims:
            total = total.reshape((1,) * x.ndim)
        return total
    try:
        # Accepts int, numpy integers and other __index__ types.
        dim = operator.index(axis)
    except TypeError:
        dims = tuple(operator.index(a) for a in axis)
        if not dims:
            # NumPy reduces nothing for axis=(); torch would reduce everything.
            return x
        # A single multi-dim call resolves negative axes against the original
        # rank; summing one axis at a time shifted them (e.g. (-2, 2) on 3-D).
        return torch.sum(x, dim=dims, keepdim=keepdims)
    return torch.sum(x, dim=dim, keepdim=keepdims)
|
|
239
|
+
|
|
240
|
+
def mean(self, x, axis=None, keepdims=False):
    """
    Arithmetic mean over the given axis/axes with NumPy semantics.

    Integer and boolean tensors are promoted to float64 first: NumPy
    returns float64 for them, whereas ``torch.mean`` rejects them.
    ``axis`` may be ``None``, an int, or an iterable of ints (negative
    values allowed). ``keepdims`` is honoured for every form of ``axis``.
    """
    import operator
    import torch
    if not (x.is_floating_point() or x.is_complex()):
        x = x.to(torch.float64)
    if axis is None:
        avg = torch.mean(x)
        if keepdims:
            avg = avg.reshape((1,) * x.ndim)
        return avg
    try:
        dim = operator.index(axis)
    except TypeError:
        dims = tuple(operator.index(a) for a in axis)
        if not dims:
            # NumPy: axis=() reduces nothing.
            return x
        # One multi-dim call keeps negative axes relative to the original rank.
        return torch.mean(x, dim=dims, keepdim=keepdims)
    return torch.mean(x, dim=dim, keepdim=keepdims)
|
|
254
|
+
|
|
255
|
+
def sqrt(self, x):
    """Square root of each element of *x*."""
    from torch import sqrt as torch_sqrt
    return torch_sqrt(x)

def abs(self, x):
    """Magnitude of each element of *x*."""
    from torch import abs as torch_abs
    return torch_abs(x)
|
|
264
|
+
|
|
265
|
+
def max(self, x, axis=None, keepdims=False):
    """
    Maximum value along the given axis/axes (NumPy ``amax`` semantics).

    ``axis`` may be ``None`` (all elements), an int, or an iterable of
    ints (negative values allowed). ``keepdims`` keeps reduced axes as
    size-1 dimensions, including for ``axis=None``.
    """
    import operator
    import torch
    if axis is None:
        peak = torch.max(x)
        if keepdims:
            peak = peak.reshape((1,) * x.ndim)
        return peak
    try:
        dim = operator.index(axis)
    except TypeError:
        dims = tuple(operator.index(a) for a in axis)
        if not dims:
            return x
        # amax reduces all dims in one call, so negative axes refer to the
        # original rank (reducing one by one used to shift them).
        return torch.amax(x, dim=dims, keepdim=keepdims)
    return torch.max(x, dim=dim, keepdim=keepdims).values
|
|
278
|
+
|
|
279
|
+
def square(self, x):
    """Each element of *x* raised to the second power."""
    from torch import square as torch_square
    return torch_square(x)

def exp(self, x):
    """``e`` raised to each element of *x*."""
    from torch import exp as torch_exp
    return torch_exp(x)

def log(self, x):
    """Natural logarithm of each element of *x*."""
    from torch import log as torch_log
    return torch_log(x)

def log1p(self, x):
    """``log(1 + x)`` per element, accurate for small *x*."""
    from torch import log1p as torch_log1p
    return torch_log1p(x)
|
|
298
|
+
|
|
299
|
+
def maximum(self, x, y):
    """
    Element-wise maximum of two arrays.

    Either argument may be a Python scalar or array-like. It is converted
    to a tensor with the dtype and device of the other argument when that
    one is a tensor, so ``maximum(t, 0.0)`` and ``maximum(0.0, t)`` both
    work (the latter used to fail on ``x.dtype``).
    """
    import torch
    if not isinstance(x, torch.Tensor):
        if isinstance(y, torch.Tensor):
            x = torch.tensor(x, dtype=y.dtype, device=y.device)
        else:
            x = torch.as_tensor(x)
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y, dtype=x.dtype, device=x.device)
    return torch.maximum(x, y)

def minimum(self, x, y):
    """
    Element-wise minimum of two arrays.

    Scalar / array-like arguments are converted exactly as in ``maximum``:
    to the dtype and device of the tensor argument.
    """
    import torch
    if not isinstance(x, torch.Tensor):
        if isinstance(y, torch.Tensor):
            x = torch.tensor(x, dtype=y.dtype, device=y.device)
        else:
            x = torch.as_tensor(x)
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y, dtype=x.dtype, device=x.device)
    return torch.minimum(x, y)
|
|
312
|
+
|
|
313
|
+
def clip(self, x, min_val, max_val):
    """Limit every element of *x* to the closed range ``[min_val, max_val]``."""
    from torch import clamp
    return clamp(x, min_val, max_val)

def where(self, cond, x, y):
    """Take elements of *x* where *cond* holds and of *y* elsewhere."""
    from torch import where as torch_where
    return torch_where(cond, x, y)

def stack(self, arrays, axis=0):
    """Join *arrays* along a newly inserted axis."""
    from torch import stack as torch_stack
    return torch_stack(arrays, dim=axis)

def cat(self, arrays, axis=0):
    """Join *arrays* along an existing axis."""
    from torch import cat as torch_cat
    return torch_cat(arrays, dim=axis)

def diag(self, x, k=0):
    """Read the *k*-th diagonal of a matrix, or build a matrix from a vector."""
    from torch import diag as torch_diag
    return torch_diag(x, diagonal=k)

def einsum(self, equation, *operands):
    """Evaluate an Einstein-summation *equation* over *operands*."""
    from torch import einsum as torch_einsum
    return torch_einsum(equation, *operands)
|
|
342
|
+
|
|
343
|
+
def transpose(self, x, axes=None):
    """
    Permute the dimensions of *x* (NumPy ``transpose`` semantics).

    With ``axes=None`` the order of all dimensions is reversed. This uses
    ``permute`` instead of ``x.T``, which torch deprecates for tensors
    whose rank is not 2.
    """
    if axes is None:
        if x.ndim < 2:
            # Reversing zero or one dimension changes nothing.
            return x
        return x.permute(*reversed(range(x.ndim)))
    return x.permute(axes)
|
|
349
|
+
|
|
350
|
+
def arange(self, start, stop=None, step=1, dtype=None):
    """
    Evenly spaced values in ``[start, stop)`` on the backend device.

    ``arange(n)`` is shorthand for ``arange(0, n)``. A floating-point
    *dtype* is passed straight to ``torch.arange`` so the values are
    generated in that precision. Any other dtype is applied by casting the
    generated result, as before.
    """
    import torch
    if stop is None:
        start, stop = 0, start
    if dtype is not None and getattr(dtype, "is_floating_point", False):
        # Generating in torch's default float dtype (float32 unless changed)
        # and upcasting afterwards would bake float32 rounding into float64.
        return torch.arange(start, stop, step, device=self._device, dtype=dtype)
    result = torch.arange(start, stop, step, device=self._device)
    if dtype is not None:
        result = result.to(dtype)
    return result
|
|
360
|
+
|
|
361
|
+
def zeros(self, shape, dtype=None):
    """Zero-filled tensor on the backend device (float64 unless *dtype* given)."""
    import torch
    if dtype is None:
        dtype = torch.float64
    return torch.zeros(shape, device=self._device, dtype=dtype)

def ones(self, shape, dtype=None):
    """One-filled tensor on the backend device (float64 unless *dtype* given)."""
    import torch
    if dtype is None:
        dtype = torch.float64
    return torch.ones(shape, device=self._device, dtype=dtype)

def eye(self, n, m=None, dtype=None):
    """``n x m`` identity-like matrix (square when *m* is omitted), float64 by default."""
    import torch
    cols = n if m is None else m
    if dtype is None:
        dtype = torch.float64
    return torch.eye(n, cols, device=self._device, dtype=dtype)

def full(self, shape, fill_value, dtype=None):
    """Tensor of *shape* filled with *fill_value*; an int shape means 1-D."""
    import torch
    size = (shape,) if isinstance(shape, int) else shape
    if dtype is None:
        dtype = torch.float64
    return torch.full(size, fill_value, device=self._device, dtype=dtype)

def array(self, val, dtype=None):
    """Build a tensor (possibly 0-d) from *val* on the backend device."""
    import torch
    if dtype is None:
        dtype = torch.float64
    return torch.tensor(val, device=self._device, dtype=dtype)
|
|
389
|
+
|
|
390
|
+
def isnan(self, x):
    """Boolean mask that is True where *x* is NaN."""
    from torch import isnan as torch_isnan
    return torch_isnan(x)

def isinf(self, x):
    """Boolean mask that is True where *x* is +inf or -inf."""
    from torch import isinf as torch_isinf
    return torch_isinf(x)

def nan_to_num(self, x, nan=0.0, posinf=None, neginf=None):
    """Copy of *x* with NaN -> *nan*, +inf -> *posinf*, -inf -> *neginf*."""
    import torch
    replacements = {"nan": nan, "posinf": posinf, "neginf": neginf}
    return torch.nan_to_num(x, **replacements)

def matmul(self, a, b):
    """Matrix product of *a* and *b* (``torch.matmul``)."""
    from torch import matmul as torch_matmul
    return torch_matmul(a, b)
|
|
409
|
+
|
|
410
|
+
def min(self, x, axis=None, keepdims=False):
    """
    Minimum value along the given axis/axes (NumPy ``amin`` semantics).

    ``axis`` may be ``None`` (all elements), an int, or an iterable of
    ints (negative values allowed), mirroring ``max``. ``keepdims`` keeps
    reduced axes as size-1 dimensions, including for ``axis=None``.
    """
    import operator
    import torch
    if axis is None:
        low = torch.min(x)
        if keepdims:
            low = low.reshape((1,) * x.ndim)
        return low
    try:
        dim = operator.index(axis)
    except TypeError:
        dims = tuple(operator.index(a) for a in axis)
        if not dims:
            return x
        # torch.min(dim=...) only takes one dim; amin takes several.
        return torch.amin(x, dim=dims, keepdim=keepdims)
    return torch.min(x, dim=dim, keepdim=keepdims).values
|
|
419
|
+
|
|
420
|
+
# NOTE: ``expand_dims``, ``argmin``, ``argmax``, ``argsort`` and ``flip`` are
# defined further down in this class. Earlier copies of them used to sit here,
# but Python keeps only the last definition of a name in a class body, so
# those copies were never used and have been removed.

def eigh(self, a):
    """
    Eigenvalue decomposition for symmetric/Hermitian matrices.

    Returns the ``(eigenvalues, eigenvectors)`` named tuple produced by
    ``torch.linalg.eigh``.
    """
    import torch
    return torch.linalg.eigh(a)
|
|
453
|
+
|
|
454
|
+
def logsumexp(self, arr, axis=None):
    """
    Numerically stable ``log(sum(exp(arr)))`` along *axis*.

    Delegates to ``torch.logsumexp``, which handles infinite maxima, so a
    slice that is entirely ``-inf`` yields ``-inf`` rather than ``nan``.
    ``axis`` may be ``None`` (all elements), an int, or a list/tuple of
    ints. Integer inputs are first cast to torch's default float dtype.
    """
    import torch
    if not (arr.is_floating_point() or arr.is_complex()):
        arr = arr.to(torch.get_default_dtype())
    if axis is None:
        # Flatten so that 0-d inputs are handled too; result is 0-d.
        return torch.logsumexp(arr.reshape(-1), dim=0)
    if isinstance(axis, (list, tuple)):
        axis = tuple(axis)
    return torch.logsumexp(arr, dim=axis)
|
|
467
|
+
|
|
468
|
+
def tensordot(self, a, b, axes=2):
    """Contract *a* and *b* over *axes* (passed to torch as ``dims``)."""
    from torch import tensordot as torch_tensordot
    return torch_tensordot(a, b, dims=axes)

def outer(self, a, b):
    """Outer product; both inputs are flattened to 1-D first, like NumPy."""
    import torch
    left = a.flatten()
    right = b.flatten()
    return torch.outer(left, right)

def newaxis(self):
    """
    Return ``None``, the index value that inserts a length-1 axis.

    NOTE(review): this is a plain method, not a property, so callers must
    write ``backend.newaxis()``; ``backend.newaxis`` alone is a bound method.
    """
    return None

def meshgrid(self, *arrays, indexing='xy'):
    """Coordinate grids from 1-D coordinate vectors (default ``'xy'`` indexing)."""
    from torch import meshgrid as torch_meshgrid
    return torch_meshgrid(*arrays, indexing=indexing)
|
|
486
|
+
|
|
487
|
+
def argmax(self, x, axis=None):
    """Position of the largest value; the flattened index when *axis* is None."""
    import torch
    return torch.argmax(x) if axis is None else torch.argmax(x, dim=axis)

def argmin(self, x, axis=None):
    """Position of the smallest value; the flattened index when *axis* is None."""
    import torch
    return torch.argmin(x) if axis is None else torch.argmin(x, dim=axis)

def sort(self, x, axis=-1):
    """Values of *x* sorted in ascending order along *axis*."""
    import torch
    ordered = torch.sort(x, dim=axis)
    return ordered.values

def argsort(self, x, axis=-1):
    """Indices that put *x* in ascending order along *axis*."""
    from torch import argsort as torch_argsort
    return torch_argsort(x, dim=axis)
|
|
510
|
+
|
|
511
|
+
def unique(self, x, return_counts=False):
    """Sorted distinct values of *x*, plus their counts when requested."""
    import torch
    if not return_counts:
        return torch.unique(x)
    return torch.unique(x, return_counts=return_counts)

def any(self, x, axis=None):
    """True if at least one element (along *axis*, if given) is truthy."""
    import torch
    return torch.any(x) if axis is None else torch.any(x, dim=axis)

def all(self, x, axis=None):
    """True if every element (along *axis*, if given) is truthy."""
    import torch
    return torch.all(x) if axis is None else torch.all(x, dim=axis)
|
|
531
|
+
|
|
532
|
+
def zeros_like(self, x, dtype=None):
    """Zeros shaped like *x* on x's device; dtype defaults to x's dtype."""
    import torch
    return torch.zeros_like(x, dtype=dtype)

def ones_like(self, x, dtype=None):
    """Ones shaped like *x* on x's device; dtype defaults to x's dtype."""
    import torch
    return torch.ones_like(x, dtype=dtype)

def full_like(self, x, fill_value, dtype=None):
    """
    Tensor shaped like *x* filled with *fill_value*.

    *dtype* (default: x's dtype) is applied at creation. Filling in x's
    dtype first and casting afterwards truncated e.g. ``0.5`` to ``0``
    for integer *x*.
    """
    import torch
    return torch.full_like(x, fill_value, dtype=dtype)
|
|
555
|
+
|
|
556
|
+
def copy(self, x):
    """Independent copy of tensor *x* (``Tensor.clone``)."""
    return x.clone()

def reshape(self, x, shape):
    """View or copy of *x* with the new *shape*."""
    return x.reshape(shape)

def flatten(self, x):
    """*x* collapsed to one dimension."""
    return x.flatten()

def squeeze(self, x, axis=None):
    """Drop size-1 dimensions: all of them, or only *axis* when given."""
    return x.squeeze() if axis is None else x.squeeze(axis)

def expand_dims(self, x, axis):
    """Insert a size-1 dimension at position *axis*."""
    return x.unsqueeze(axis)

def atleast_1d(self, x):
    """Tensor version of *x* with at least one dimension."""
    import torch
    t = torch.as_tensor(x)
    return t.reshape(1) if t.ndim == 0 else t

def astype(self, x, dtype):
    """*x* converted to *dtype* (``Tensor.to``)."""
    return x.to(dtype)
|
|
595
|
+
|
|
596
|
+
def concatenate(self, arrays, axis=0):
    """Join *arrays* end to end along *axis* (NumPy name for ``torch.cat``)."""
    from torch import cat as torch_cat
    return torch_cat(arrays, dim=axis)

def take_along_axis(self, arr, indices, axis):
    """Pick values of *arr* at *indices* along *axis* (``torch.take_along_dim``)."""
    from torch import take_along_dim
    return take_along_dim(arr, indices, dim=axis)

def cummin(self, arr, axis=0):
    """Running minimum of *arr* along *axis*; only the values are returned."""
    import torch
    return torch.cummin(arr, dim=axis).values

def cummax(self, arr, axis=0):
    """Running maximum of *arr* along *axis*; only the values are returned."""
    import torch
    return torch.cummax(arr, dim=axis).values
|
|
617
|
+
|
|
618
|
+
def flip(self, arr, axis=0):
    """
    Reverse the order of elements along *axis* (torch.flip).

    Parameters
    ----------
    arr : torch.Tensor
    axis : int, list/tuple of int, or None, default=0
        Axis or axes to reverse. ``None`` reverses every axis, like
        ``np.flip(m, axis=None)``. The default of ``0`` is kept for
        backward compatibility, although NumPy's default is ``None``.
    """
    import torch
    if axis is None:
        dims = list(range(arr.ndim))
    elif isinstance(axis, (list, tuple)):
        dims = list(axis)
    else:
        dims = [axis]
    return torch.flip(arr, dims=dims)
|
|
622
|
+
|
|
623
|
+
@property
def float64(self):
    """The ``torch.float64`` dtype (double precision)."""
    import torch
    dtype = torch.float64
    return dtype

@property
def float32(self):
    """The ``torch.float32`` dtype (single precision)."""
    import torch
    dtype = torch.float32
    return dtype

@property
def int64(self):
    """The ``torch.int64`` dtype."""
    import torch
    dtype = torch.int64
    return dtype

@property
def int32(self):
    """The ``torch.int32`` dtype."""
    import torch
    dtype = torch.int32
    return dtype

@property
def bool(self):
    """The ``torch.bool`` dtype."""
    import torch
    dtype = torch.bool
    return dtype
|
|
652
|
+
|
|
653
|
+
@property
def nan(self):
    """0-d float64 NaN tensor on the backend device."""
    import torch
    value = float('nan')
    return torch.tensor(value, dtype=torch.float64, device=self._device)

@property
def inf(self):
    """0-d float64 +infinity tensor on the backend device."""
    import torch
    value = float('inf')
    return torch.tensor(value, dtype=torch.float64, device=self._device)

@property
def pi(self):
    """0-d float64 tensor holding pi, on the backend device."""
    import torch
    value = 3.141592653589793
    return torch.tensor(value, dtype=torch.float64, device=self._device)

def empty_cache(self):
    """Release cached CUDA memory held by torch; does nothing without CUDA."""
    import torch
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def count_nonzero(self, x):
    """Number of non-zero entries in *x*."""
    from torch import count_nonzero as torch_count_nonzero
    return torch_count_nonzero(x)

def sign(self, x):
    """-1, 0 or +1 per element, following the sign of *x*."""
    from torch import sign as torch_sign
    return torch_sign(x)
|