statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,529 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Backend utilities for GLM loss functions.
|
|
3
|
+
|
|
4
|
+
Provides wrapper functions that dispatch to numpy/cupy/torch
|
|
5
|
+
based on the input array type, so GLM loss functions can use
|
|
6
|
+
a single code path for all backends.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from statgpu.backends._base import _resolve_backend
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _xp(arr):
|
|
15
|
+
"""Get the array module (numpy/cupy/torch) from array type."""
|
|
16
|
+
mod = type(arr).__module__
|
|
17
|
+
if mod.startswith("cupy"):
|
|
18
|
+
import cupy
|
|
19
|
+
return cupy
|
|
20
|
+
if mod.startswith("torch"):
|
|
21
|
+
import torch
|
|
22
|
+
return torch
|
|
23
|
+
import numpy
|
|
24
|
+
return numpy
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _clip(arr, lo, hi):
    """Clip *arr* to the closed interval ``[lo, hi]``.

    Either bound may be ``None`` to leave that side open. PyTorch tensors go
    through ``torch.clamp`` (returning *arr* itself when both bounds are
    ``None``); NumPy/CuPy arrays use ``xp.clip``.
    """
    xp = _xp(arr)
    if xp.__name__ != "torch":
        return xp.clip(arr, lo, hi)
    bounds = {}
    if lo is not None:
        bounds["min"] = lo
    if hi is not None:
        bounds["max"] = hi
    return xp.clamp(arr, **bounds) if bounds else arr
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _exp(arr):
    """Element-wise ``e**arr`` computed by the namespace that owns *arr*."""
    namespace = _xp(arr)
    return namespace.exp(arr)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _log(arr):
    """Element-wise natural logarithm computed by *arr*'s own namespace."""
    namespace = _xp(arr)
    return namespace.log(arr)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _log1p(arr):
    """Element-wise ``log(1 + arr)`` computed by *arr*'s own namespace."""
    namespace = _xp(arr)
    return namespace.log1p(arr)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _sigmoid(arr):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` with overflow-avoiding clipping.

    Inputs are clipped to +/-88 when the dtype name contains ``"32"``
    (float32 ``exp`` overflows near 89) and to +/-700 otherwise (float64
    overflows near 709). PyTorch tensors then use ``torch.sigmoid``;
    NumPy/CuPy evaluate the closed form.
    """
    xp = _xp(arr)
    dt = getattr(arr, 'dtype', None)
    is_32bit = dt is not None and '32' in str(dt)
    bound = 88.0 if is_32bit else 700.0
    clipped = _clip(arr, -bound, bound)
    if xp.__name__ != "torch":
        return 1.0 / (1.0 + xp.exp(-clipped))
    return xp.sigmoid(clipped)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _softplus(x):
    """Softplus ``log(1 + exp(x))`` evaluated without overflow.

    NumPy/CuPy use the identity ``log1p(exp(-|x|)) + max(x, 0)``; PyTorch
    delegates to ``torch.nn.functional.softplus``.
    """
    xp = _xp(x)
    if xp.__name__ != "torch":
        positive_part = _clip(x, 0.0, None)
        return xp.log1p(xp.exp(-xp.abs(x))) + positive_part
    from torch.nn import functional
    return functional.softplus(x)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _sum(arr):
    """Total of all elements, left on *arr*'s backend (no host conversion)."""
    namespace = _xp(arr)
    return namespace.sum(arr)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _eigvalsh(arr):
    """Eigenvalues (ascending) of the symmetric matrix *arr*.

    Delegates to ``xp.linalg.eigvalsh`` of the namespace owning *arr*.
    """
    linalg = _xp(arr).linalg
    return linalg.eigvalsh(arr)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _zeros_like(arr):
    """Zeros with the same shape and dtype as *arr*, on *arr*'s backend."""
    namespace = _xp(arr)
    return namespace.zeros_like(arr)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _zeros(n, backend, ref_tensor=None, dtype=None):
    """Return a 1-D vector of *n* zeros on the requested backend.

    Parameters
    ----------
    n : int
        Length of the vector.
    backend : str
        Backend name or alias; ``"auto"`` infers it from *ref_tensor*.
    ref_tensor : array, optional
        CuPy/PyTorch: supplies the default dtype (PyTorch also the device).
        Not consulted on the NumPy path.
    dtype : optional
        Explicit dtype; takes precedence over *ref_tensor*.
    """
    resolved = _resolve_backend(backend, ref_tensor)
    if resolved == "numpy":
        return np.zeros(n, dtype=dtype)
    if resolved == "cupy":
        import cupy as cp
        if dtype is None:
            dtype = getattr(ref_tensor, "dtype", cp.float64)
        return cp.zeros(n, dtype=dtype)
    import torch
    if ref_tensor is None:
        device, fallback_dtype = "cpu", torch.float64
    else:
        device = getattr(ref_tensor, "device", "cpu")
        fallback_dtype = getattr(ref_tensor, "dtype", torch.float64)
    return torch.zeros(n, device=device, dtype=dtype or fallback_dtype)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _copy_arr(arr):
|
|
120
|
+
"""Copy array: .clone() for torch, .copy() for numpy/cupy."""
|
|
121
|
+
if hasattr(arr, "clone"):
|
|
122
|
+
return arr.clone()
|
|
123
|
+
return arr.copy()
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _diag(reg, backend="auto", ref_tensor=None, dtype=None):
    """Build a square matrix with *reg* on its diagonal.

    Parameters
    ----------
    reg : array-like
        Diagonal entries.
    backend : str
        Backend name or alias; ``"auto"`` infers it from *ref_tensor*, then
        *reg*.
    ref_tensor : array, optional
        PyTorch path: provides the device and, when floating point, dtype.
    dtype : optional
        Explicit output dtype.
    """
    target = _resolve_backend(backend, ref_tensor, reg)
    if target == "torch":
        import torch
        if ref_tensor is not None:
            device = ref_tensor.device
        else:
            device = getattr(reg, "device", "cpu")
        out_dtype = dtype
        if not out_dtype:
            ref_is_float = ref_tensor is not None and getattr(
                ref_tensor, "is_floating_point", lambda: False
            )()
            if ref_is_float:
                out_dtype = ref_tensor.dtype
            elif hasattr(reg, "is_floating_point") and reg.is_floating_point():
                out_dtype = reg.dtype
            else:
                out_dtype = torch.float64
        return torch.diag(torch.as_tensor(reg, dtype=out_dtype, device=device))
    if target == "cupy":
        import cupy as cp
        if dtype is None:
            dtype = getattr(reg, "dtype", cp.float64)
        return cp.diag(cp.asarray(reg, dtype=dtype))
    values = reg if dtype is None else np.asarray(reg, dtype=dtype)
    return np.diag(values)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _to_backend(arr, backend="auto", ref_tensor=None, dtype=None):
    """Convert *arr* to the requested backend, matching *ref_tensor* when given.

    Parameters
    ----------
    arr : array-like
        Source data: Python sequence, NumPy, CuPy or PyTorch array.
    backend : str
        Target backend name or alias; ``"auto"`` infers it from *ref_tensor*,
        then *arr*.
    ref_tensor : array, optional
        CuPy target: its dtype is reused when that dtype's name contains
        ``"float"``. PyTorch target: supplies the device and, when floating
        point, the dtype.
    dtype : optional
        Explicit output dtype; overrides everything else.

    Returns
    -------
    array
        Array on the target backend. CuPy and PyTorch targets fall back to
        float64 when no floating dtype can be inferred; the NumPy target
        falls back to Python ``float`` (float64).
    """
    backend = _resolve_backend(backend, ref_tensor, arr)
    if backend == "cupy":
        import cupy as cp
        out_dtype = dtype
        if out_dtype is None:
            ref_dtype = getattr(ref_tensor, "dtype", None)
            if ref_dtype is not None and 'float' in str(ref_dtype):
                out_dtype = ref_dtype
            else:
                out_dtype = cp.float64
        return cp.asarray(arr, dtype=out_dtype)
    if backend == "torch":
        import torch
        device = (
            ref_tensor.device
            if ref_tensor is not None
            else getattr(arr, "device", "cpu")
        )
        out_dtype = dtype or (
            ref_tensor.dtype
            if ref_tensor is not None
            and getattr(ref_tensor, "is_floating_point", lambda: False)()
            else arr.dtype
            if hasattr(arr, "is_floating_point")
            and arr.is_floating_point()
            else torch.float64
        )
        return torch.as_tensor(arr, dtype=out_dtype, device=device)
    # NumPy target: device arrays must be brought to the host explicitly,
    # because np.asarray rejects CuPy arrays and CUDA or grad-tracking
    # torch tensors.
    source = _resolve_backend("auto", arr)
    if source == "cupy":
        arr = arr.get()
    elif source == "torch":
        arr = arr.detach().cpu().numpy()
    return np.asarray(arr, dtype=dtype or float)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _solve_linear_system(A, b, backend="auto"):
    """Solve ``A x = b``, falling back to least squares if *A* is singular.

    Parameters
    ----------
    A : 2-d array
        Coefficient matrix.
    b : 1-d or 2-d array
        Right-hand side(s); a 1-d *b* yields a 1-d solution.
    backend : str
        Backend name or alias; ``"auto"`` infers it from *A*.

    Returns
    -------
    array
        Exact solution when *A* is invertible, otherwise a least-squares
        solution on the same backend.
    """
    backend = _resolve_backend(backend, A)
    try:
        if backend == "torch":
            import torch
            b_col = b.unsqueeze(1) if b.ndim == 1 else b
            sol = torch.linalg.solve(A, b_col)
            return sol.squeeze(1) if b.ndim == 1 else sol
        if backend == "cupy":
            import cupy as cp
            sol = cp.linalg.solve(A, b)
            # By default CuPy ignores cuSOLVER error codes (see
            # cupyx.errstate(linalg=...)), so a singular A does not raise;
            # it shows up as inf/nan in the solution instead.
            if not bool(cp.all(cp.isfinite(sol))):
                raise np.linalg.LinAlgError("Singular matrix")
            return sol
        return np.linalg.solve(A, b)
    except (np.linalg.LinAlgError, RuntimeError):
        # LinAlgError for numpy singular matrices (and cupy, raised above)
        # RuntimeError for torch singular matrices
        if backend == "torch":
            import torch
            b_col = b.unsqueeze(1) if b.ndim == 1 else b
            if A.is_cuda:
                # On CUDA torch.linalg.lstsq only supports the 'gels' driver,
                # which assumes full rank -- the very condition that failed.
                sol = torch.linalg.pinv(A) @ b_col
            else:
                sol = torch.linalg.lstsq(A, b_col).solution
            return sol.squeeze(1) if b.ndim == 1 else sol
        if backend == "cupy":
            import cupy as cp
            # rcond=None matches the NumPy path below.
            return cp.linalg.lstsq(A, b, rcond=None)[0]
        return np.linalg.lstsq(A, b, rcond=None)[0]
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _eye_like(n, ref):
    """Return an ``n x n`` identity matrix on *ref*'s backend.

    CuPy and PyTorch reuse ``ref.dtype`` (PyTorch also ``ref.device``);
    NumPy uses ``ref.dtype`` when present, otherwise float64.
    """
    kind = _resolve_backend("auto", ref)
    if kind == "torch":
        import torch
        return torch.eye(n, dtype=ref.dtype, device=ref.device)
    if kind == "cupy":
        import cupy as cp
        return cp.eye(n, dtype=ref.dtype)
    return np.eye(n, dtype=getattr(ref, "dtype", np.float64))
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _sync_scalars(*dev_vals, backend):
    """Convert scalars to Python floats with a single device-to-host transfer.

    Parameters
    ----------
    *dev_vals : scalar-like
        0-d (or single-element) arrays/tensors or Python numbers, all on the
        same backend.
    backend : str
        Backend name or alias; ``"auto"`` infers it from *dev_vals*.

    Returns
    -------
    tuple of float
        One Python float per input, in input order (empty for no inputs).
    """
    backend = _resolve_backend(backend, *dev_vals)
    if not dev_vals:
        return ()
    if backend == "numpy":
        return tuple(float(v) for v in dev_vals)
    count = len(dev_vals)
    if backend == "torch":
        import torch
        ref = next((v for v in dev_vals if isinstance(v, torch.Tensor)), None)
        device = getattr(ref, "device", None)
        # Adopt the reference dtype only if it is floating point: an integer
        # or bool reference would silently truncate fractional values.
        if ref is not None and ref.is_floating_point():
            dtype = ref.dtype
        else:
            dtype = torch.float64
        stacked = torch.stack(
            [torch.as_tensor(v, device=device, dtype=dtype) for v in dev_vals]
        )
        # One transfer for the whole batch; per-element .item() calls would
        # each synchronise with the device.
        return tuple(float(x) for x in stacked.reshape(count).tolist())
    import cupy as cp
    stacked = cp.stack([cp.asarray(v) for v in dev_vals])
    # cp.asnumpy copies the whole batch to the host at once.
    host = cp.asnumpy(stacked).reshape(count)
    return tuple(float(x) for x in host)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def _abs_sum(x):
    """Return the L1 norm ``sum(|x|)`` as a Python float."""
    xp = _xp(x)
    total = xp.sum(xp.abs(x))
    return float(total.item() if xp.__name__ == "torch" else total)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def _abs_max(x):
    """Return the largest absolute entry ``max(|x|)`` as a Python float."""
    xp = _xp(x)
    peak = xp.max(xp.abs(x))
    return float(peak.item() if xp.__name__ == "torch" else peak)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def _norm2(x):
    """Return the Euclidean (L2) norm of *x* as a Python float."""
    xp = _xp(x)
    length = xp.linalg.norm(x)
    return float(length.item() if xp.__name__ == "torch" else length)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def _dot(a, b):
|
|
277
|
+
"""Dot product, returned as a Python scalar."""
|
|
278
|
+
val = a.dot(b)
|
|
279
|
+
return float(val.item() if hasattr(val, "item") else val)
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _dot_dev(a, b):
|
|
283
|
+
"""Dot product staying on device for GPU backends."""
|
|
284
|
+
if isinstance(a, np.ndarray):
|
|
285
|
+
return float(a.dot(b))
|
|
286
|
+
return a.dot(b)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def _sum_sq(x):
    """Return the sum of squared entries of *x* as a Python float."""
    total = _xp(x).sum(x ** 2)
    if hasattr(total, "item"):
        total = total.item()
    return float(total)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _sum_sq_dev(x):
    """Sum of squares that stays on device for GPU backends.

    NumPy input returns a Python float; CuPy/PyTorch input returns the
    backend's reduction result without converting it on the host.
    """
    xp = _xp(x)
    total = xp.sum(x ** 2)
    return float(total) if xp.__name__ == "numpy" else total
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def _norm2_dev(x):
    """L2 norm that stays on device for GPU backends.

    NumPy input returns a Python float; CuPy/PyTorch input returns the
    backend's norm result without converting it on the host.
    """
    xp = _xp(x)
    length = xp.linalg.norm(x)
    return float(length) if xp.__name__ == "numpy" else length
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _abs_sum_dev(x):
    """Sum of absolute values that stays on device for GPU backends.

    NumPy input returns a Python float; CuPy/PyTorch input returns the
    backend's reduction result without converting it on the host.
    """
    xp = _xp(x)
    total = xp.sum(xp.abs(x))
    return float(total) if xp.__name__ == "numpy" else total
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def _device_leq(a, b):
    """Evaluate ``a <= b`` and return the result as a Python ``bool``.

    The backend is inferred from *a* and *b*. PyTorch results are read with
    ``.item()``; CuPy and NumPy results are converted with ``bool()``, so
    the NumPy path also yields a plain ``bool`` instead of ``numpy.bool_``.
    Inputs are expected to be scalars (0-d or single-element); multi-element
    inputs are not supported on any backend.
    """
    backend = _resolve_backend("auto", a, b)
    result = a <= b
    if backend == "torch":
        return bool(result.item())
    return bool(result)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def _device_gt(a, b):
    """Evaluate ``a > b`` and return the result as a Python ``bool``.

    The backend is inferred from *a* and *b*. PyTorch results are read with
    ``.item()``; CuPy and NumPy results are converted with ``bool()``, so
    the NumPy path also yields a plain ``bool`` instead of ``numpy.bool_``.
    Inputs are expected to be scalars (0-d or single-element); multi-element
    inputs are not supported on any backend.
    """
    backend = _resolve_backend("auto", a, b)
    result = a > b
    if backend == "torch":
        return bool(result.item())
    return bool(result)
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def _clip_grad_on_device(grad, coef_old, backend):
    """Rescale *grad* so its L2 norm does not exceed a coefficient-based cap.

    The cap is ``max(sum(|coef_old|) * _GRAD_CLIP_COEF_FACTOR +
    _GRAD_CLIP_ABS_FLOOR, _GRAD_CLIP_MAX)``. When ``||grad||_2`` is above
    the cap, the gradient is scaled down to (approximately) the cap. NumPy
    returns *grad* unchanged otherwise. CuPy/PyTorch compute the scale with
    ``where``, so no value is read back to the host.

    Parameters
    ----------
    grad, coef_old : array
        Gradient and current coefficients on the same backend.
    backend : str
        Backend name or alias ("auto", "cpu", "cuda", ... are resolved first).

    Returns
    -------
    array
        The (possibly rescaled) gradient on the same backend.
    """
    # Lazy import to avoid circular dependency (backends <-> solvers)
    from statgpu.solvers._constants import (
        _GRAD_CLIP_COEF_FACTOR, _GRAD_CLIP_ABS_FLOOR, _GRAD_CLIP_MAX,
    )
    # Normalise aliases such as "auto"/"cpu"/"cuda" so they do not fall
    # through to the CuPy branch below.
    backend = _resolve_backend(backend, grad)
    if backend == "numpy":
        gn = float(np.linalg.norm(grad))
        ca = float(np.sum(np.abs(coef_old)))
        gmax = max(ca * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR, _GRAD_CLIP_MAX)
        if gn > gmax:
            return grad * (gmax / gn)
        return grad
    if backend == "torch":
        import torch
        gn_sq = torch.sum(grad ** 2)
        coef_abs = torch.sum(torch.abs(coef_old))
        gmax = coef_abs * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR
        gmax = torch.clamp(gmax, min=_GRAD_CLIP_MAX)
        # Compare squared norms to skip a sqrt; 1e-30 guards sqrt(0).
        scale = torch.where(
            gn_sq > gmax * gmax,
            gmax / torch.sqrt(gn_sq + 1e-30),
            torch.ones(1, device=grad.device, dtype=grad.dtype),
        )
        return grad * scale
    import cupy as cp
    gn_sq = cp.sum(grad ** 2)
    coef_abs = cp.sum(cp.abs(coef_old))
    gmax = cp.maximum(coef_abs * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR, _GRAD_CLIP_MAX)
    scale = cp.where(
        gn_sq > gmax * gmax,
        gmax / cp.sqrt(gn_sq + 1e-30),
        cp.ones(1, dtype=grad.dtype),
    )
    return grad * scale
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def _max_eigval_power(mat, n_iter=20, tol=1e-8):
    """Largest eigenvalue of a symmetric matrix via power iteration.

    Much faster than full eigendecomposition, especially on GPU
    where cuSOLVER eigvalsh has high kernel compilation overhead.
    Each iteration costs one matrix-vector product, O(p^2), versus O(p^3)
    for a full eigendecomposition. NOTE(review): the claimed accuracy
    (within ~1% after 20 iterations) is not established here; convergence
    speed depends on the gap between the two largest eigenvalues.

    The per-iteration estimate is ``||mat @ v||`` for a unit vector ``v``.
    For symmetric PSD input this never exceeds the largest eigenvalue in
    exact arithmetic, so a truncated run tends to under-estimate it.

    Parameters
    ----------
    mat : 2-d array (p, p), symmetric positive semi-definite.
    n_iter : int
        Max power iterations.
    tol : float
        Early stopping tolerance on the relative eigenvalue change
        (``|new - old| < tol * |new|``).

    Returns
    -------
    float : max eigenvalue estimate. Returns 1.0 when the seed vector has
        (near-)zero norm (p == 0) or when ``mat @ v`` vanishes (e.g. a zero
        matrix).
    """
    xp = _xp(mat)
    p = mat.shape[0]
    dtype = getattr(mat, 'dtype', None)
    # Build a deterministic but non-constant seed vector to avoid the
    # pathological case where an all-ones vector is orthogonal to the top
    # eigenspace (e.g., [[1,-1],[-1,1]]).
    if xp.__name__ == "torch":
        v = xp.arange(1, p + 1, dtype=dtype, device=mat.device)
    elif dtype is not None:
        v = xp.arange(1, p + 1, dtype=dtype)
    else:
        v = xp.arange(1, p + 1, dtype=xp.float64)

    v_norm = xp.sqrt(xp.dot(v, v))
    v_norm_val = float(v_norm)
    if v_norm_val < 1e-15:
        # Only reachable for an empty (p == 0) matrix.
        return 1.0
    v = v / v_norm

    if xp.__name__ == "numpy":
        lambda_old = 0.0
        lambda_new = 0.0
        for _ in range(n_iter):
            v_new = mat @ v
            # Compute ||v_new||^2 once; it serves both the zero-check and
            # the normalisation below.
            nv2 = xp.dot(v_new, v_new)
            v_norm_sq = float(nv2)
            if v_norm_sq < 1e-30:
                return 1.0
            v_norm = v_norm_sq ** 0.5
            v = v_new / v_norm
            # v has already been overwritten with the normalised v_new, so
            # dot(v, v_new) = ||v_new|| = ||A v_prev|| (norm of the image of
            # the previous unit vector), not the Rayleigh quotient of v.
            lambda_new = float(xp.dot(v, v_new))
            # Early stop once the relative change drops below tol; the
            # lambda_old > 0 guard skips the check on the first pass.
            if lambda_old > 0 and abs(lambda_new - lambda_old) < tol * abs(lambda_new):
                break
            lambda_old = lambda_new
        return lambda_new

    # CuPy / PyTorch path: same recurrence; scalars are pulled to the host
    # each iteration for the convergence test.
    lambda_old = 0.0
    lambda_val = 0.0
    for i in range(n_iter):
        v_new = mat @ v
        dot_vn_vn = xp.dot(v_new, v_new)
        v_norm_sq = float(dot_vn_vn.item() if hasattr(dot_vn_vn, "item") else dot_vn_vn)
        if v_norm_sq < 1e-30:
            return 1.0  # Zero matrix — same fallback as numpy path
        v_norm = v_norm_sq ** 0.5
        v = v_new / v_norm
        # As above: this equals ||A v_prev||.
        lambda_new = xp.dot(v, v_new)
        lambda_val = float(lambda_new.item() if hasattr(lambda_new, "item") else lambda_new)
        # Unlike the NumPy path, the first-pass guard here is i > 0.
        if i > 0 and abs(lambda_val - lambda_old) < tol * abs(lambda_val):
            return lambda_val
        lambda_old = lambda_val
    return lambda_val
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def _soft_threshold(w, thresh):
    """Soft-thresholding operator ``sign(w) * max(|w| - thresh, 0)``.

    Works for NumPy, CuPy and PyTorch inputs. ``thresh`` may be a scalar or
    an array with the same shape as ``w`` (e.g. adaptive per-coefficient
    weights). ``xp.where`` keeps the number of temporaries low compared with
    a ``sign * clip`` formulation.
    """
    xp = _xp(w)
    magnitude = xp.abs(w)
    shrunk = xp.where(magnitude > thresh, magnitude - thresh, 0.0)
    # Adding +0.0 turns the -0.0 produced by sign(w) * 0 into +0.0.
    return (shrunk * xp.sign(w)) + 0.0
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def _scalar_tensor(val, ref_arr):
    """Wrap *val* as a scalar compatible with *ref_arr*'s backend.

    PyTorch: a 0-d tensor with ``ref_arr``'s dtype and device.
    NumPy/CuPy: a plain Python ``float``, which these libraries broadcast
    directly.
    """
    if _xp(ref_arr).__name__ != "torch":
        return float(val)
    import torch
    return torch.tensor(val, dtype=ref_arr.dtype, device=ref_arr.device)
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def _xp_copy(arr):
    """Copy *arr* without leaving its backend (``clone`` for torch, else ``copy``)."""
    copier = arr.clone if _xp(arr).__name__ == "torch" else arr.copy
    return copier()
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def _xp_zeros(shape, dtype, ref_arr):
    """Allocate zeros of *shape* on *ref_arr*'s backend.

    A falsy *dtype* falls back to ``ref_arr.dtype``; PyTorch output is
    placed on ``ref_arr.device``.
    """
    xp = _xp(ref_arr)
    if xp.__name__ != "torch":
        return xp.zeros(shape, dtype=dtype or getattr(ref_arr, 'dtype', None))
    import torch
    return torch.zeros(shape, dtype=dtype or ref_arr.dtype, device=ref_arr.device)
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def _xp_asarray(arr, dtype, ref_arr):
    """Move/cast *arr* onto *ref_arr*'s backend with the requested *dtype*.

    Handles numpy -> cupy, numpy -> torch and same-backend dtype casts.
    Non-tensor inputs bound for PyTorch are first materialised as float64
    NumPy arrays. A PyTorch *dtype* supplied for a CuPy target is
    translated to its NumPy equivalent.
    """
    xp = _xp(ref_arr)
    kind = xp.__name__
    if kind == "torch":
        import torch
        if not isinstance(arr, torch.Tensor):
            host = np.asarray(arr, dtype=np.float64)
            return torch.as_tensor(host, dtype=dtype, device=ref_arr.device)
        return arr.to(dtype=dtype, device=ref_arr.device)
    if kind == "cupy":
        # Convert torch dtypes to numpy for cupy compatibility
        if 'torch' in str(getattr(dtype, '__module__', '')):
            from statgpu.backends._utils import _torch_dtype_to_np
            dtype = _torch_dtype_to_np(dtype)
        return xp.asarray(arr, dtype=dtype)
    return np.asarray(arr, dtype=dtype)
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def _xp_eye(n, dtype, ref_arr):
    """Return an ``n x n`` identity on *ref_arr*'s backend.

    A falsy *dtype* falls back to ``ref_arr.dtype``; PyTorch output is
    placed on ``ref_arr.device``.
    """
    xp = _xp(ref_arr)
    if xp.__name__ != "torch":
        return xp.eye(n, dtype=dtype or getattr(ref_arr, 'dtype', None))
    import torch
    return torch.eye(n, dtype=dtype or ref_arr.dtype, device=ref_arr.device)
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Abstract base class for compute backends.
|
|
3
|
+
|
|
4
|
+
A backend wraps an array library (NumPy, CuPy, or PyTorch) and exposes a
|
|
5
|
+
uniform interface so that model implementations can stay array-library agnostic.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from typing import Any, Optional
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# ---------------------------------------------------------------------------
|
|
15
|
+
# Array-type detection helpers (deferred imports to avoid hard deps)
|
|
16
|
+
# ---------------------------------------------------------------------------
|
|
17
|
+
|
|
18
|
+
def _is_cupy_array(x: Any) -> bool:
|
|
19
|
+
"""Return True if *x* is a CuPy ndarray."""
|
|
20
|
+
try:
|
|
21
|
+
import cupy as cp
|
|
22
|
+
return isinstance(x, cp.ndarray)
|
|
23
|
+
except Exception:
|
|
24
|
+
return False
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _is_torch_array(x: Any) -> bool:
|
|
28
|
+
"""Return True if *x* is a PyTorch Tensor."""
|
|
29
|
+
try:
|
|
30
|
+
import torch
|
|
31
|
+
return isinstance(x, torch.Tensor)
|
|
32
|
+
except Exception:
|
|
33
|
+
return False
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _resolve_backend(backend: str, *arrays) -> str:
|
|
37
|
+
"""Resolve the named *backend* string to one of ``'numpy'``, ``'cupy'``,
|
|
38
|
+
``'torch'``.
|
|
39
|
+
|
|
40
|
+
Accepts legacy aliases ``'cpu'`` → ``'numpy'`` and ``'cuda'``/``'gpu'`` → ``'cupy'``.
|
|
41
|
+
When *backend* is ``'auto'``, inspect *arrays* and return the
|
|
42
|
+
matching backend name based on the first recognised array type.
|
|
43
|
+
Falls back to ``'numpy'`` when no array matches.
|
|
44
|
+
"""
|
|
45
|
+
backend_name = str(backend).strip().lower()
|
|
46
|
+
backend_name = {"cpu": "numpy", "cuda": "cupy", "gpu": "cupy"}.get(
|
|
47
|
+
backend_name, backend_name
|
|
48
|
+
)
|
|
49
|
+
if backend_name not in ("auto", "numpy", "cupy", "torch"):
|
|
50
|
+
raise ValueError(
|
|
51
|
+
"backend must be one of: 'auto', 'numpy', 'cupy', 'torch', "
|
|
52
|
+
"or legacy aliases 'cpu', 'cuda', 'gpu'"
|
|
53
|
+
)
|
|
54
|
+
if backend_name != "auto":
|
|
55
|
+
return backend_name
|
|
56
|
+
|
|
57
|
+
for arr in arrays:
|
|
58
|
+
if arr is not None:
|
|
59
|
+
if _is_torch_array(arr):
|
|
60
|
+
return "torch"
|
|
61
|
+
if _is_cupy_array(arr):
|
|
62
|
+
return "cupy"
|
|
63
|
+
return "numpy"
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class BackendBase(ABC):
    """
    Abstract base for compute backends.

    Subclasses wrap a specific array library and expose:

    * ``xp`` – the underlying array module (numpy / cupy / torch).
    * ``asarray`` – convert arbitrary inputs to the backend's native array.
    * ``to_numpy`` – convert the backend's arrays back to ``numpy.ndarray``.
    * ``is_available`` – runtime check for the library being usable.

    The ``xp`` object follows the NumPy array API so that operations such as
    ``xp.linalg.solve``, ``xp.sum``, ``xp.exp`` etc. work without
    library-specific branches in the calling code.

    The non-abstract helpers below use NumPy-style calls (``axis=``,
    ``.astype``, ``.copy``, ``ufunc.accumulate``). NOTE(review): backends
    whose module deviates from that API (presumably PyTorch) are expected
    to override them; confirm in the concrete subclasses.
    """

    #: Short name used in repr and config ('numpy', 'cupy', 'torch').
    name: str = ""

    # ------------------------------------------------------------------
    # Abstract interface
    # ------------------------------------------------------------------

    @property
    @abstractmethod
    def xp(self) -> Any:
        """Return the array module (numpy / cupy / torch)."""

    @abstractmethod
    def asarray(self, x, dtype=None) -> Any:
        """
        Convert *x* to this backend's native array type.

        Parameters
        ----------
        x : array-like, numpy.ndarray, cupy.ndarray, or torch.Tensor
            Input data.
        dtype : dtype-like, optional
            Desired data type.

        Returns
        -------
        array
            Native array on the backend's device.
        """

    @abstractmethod
    def to_numpy(self, x) -> np.ndarray:
        """
        Convert *x* to a ``numpy.ndarray``.

        Parameters
        ----------
        x : array-like
            A native array produced by this backend (or any array-like).

        Returns
        -------
        numpy.ndarray
        """

    @abstractmethod
    def is_available(self) -> bool:
        """Return True if this backend can be used in the current environment."""

    # ------------------------------------------------------------------
    # Convenience helpers (non-abstract, built on top of xp)
    # ------------------------------------------------------------------

    def solve(self, A, b):
        """Solve the linear system *Ax = b*."""
        return self.xp.linalg.solve(A, b)

    def lstsq(self, A, b, rcond=None):
        """Return the least-squares solution to *Ax ≈ b*."""
        return self.xp.linalg.lstsq(A, b, rcond=rcond)

    def astype(self, arr, dtype):
        """Cast *arr* to *dtype* via ``arr.astype`` (override for ``.to``)."""
        return arr.astype(dtype)

    def concatenate(self, arrays, axis=0):
        """Concatenate *arrays* along *axis* (.concatenate / .cat)."""
        return self.xp.concatenate(arrays, axis=axis)

    def take_along_axis(self, arr, indices, axis):
        """Gather elements along *axis* (.take_along_axis / .take_along_dim)."""
        return self.xp.take_along_axis(arr, indices, axis=axis)

    def cummin(self, arr, axis=0):
        """Cumulative minimum along *axis*."""
        return self.xp.minimum.accumulate(arr, axis=axis)

    def cummax(self, arr, axis=0):
        """Cumulative maximum along *axis*."""
        return self.xp.maximum.accumulate(arr, axis=axis)

    def flip(self, arr, axis=0):
        """Reverse the order of elements along *axis*."""
        return self.xp.flip(arr, axis=axis)

    def copy(self, arr):
        """Return a copy of *arr*."""
        return arr.copy()

    def reshape(self, arr, shape):
        """Reshape *arr* to *shape*."""
        return arr.reshape(shape)

    def logsumexp(self, arr, axis=None):
        """Log-sum-exp along *axis*, using the max-shift trick for stability.

        Non-finite maxima are replaced by 0 before shifting. Otherwise a
        slice that is entirely ``-inf`` (or contains ``+inf``) would compute
        ``inf - inf`` and return ``nan``. With the replacement it correctly
        returns ``-inf`` (or ``+inf``).
        """
        xp = self.xp
        m = xp.max(arr, axis=axis, keepdims=True)
        m = xp.where(xp.isfinite(m), m, 0)
        return xp.squeeze(m, axis=axis) + xp.log(xp.sum(xp.exp(arr - m), axis=axis))

    def __repr__(self) -> str:
        available = "available" if self.is_available() else "unavailable"
        return f"{self.__class__.__name__}(name={self.name!r}, {available})"
|