statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2610 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unified probability-distribution backend.
|
|
3
|
+
|
|
4
|
+
Supports ``numpy``, ``cupy``, and ``torch`` backends through a single
|
|
5
|
+
``SpecialFunctions`` protocol, eliminating code duplication across
|
|
6
|
+
``_distributions_gpu.py`` and ``_distributions_torch.py``.
|
|
7
|
+
|
|
8
|
+
Usage::
|
|
9
|
+
|
|
10
|
+
from statgpu.inference._distributions_backend import get_distribution, norm, t
|
|
11
|
+
|
|
12
|
+
# Explicit backend
|
|
13
|
+
norm_dist = get_distribution("norm", backend="numpy")
|
|
14
|
+
norm_dist.cdf([0.0, 1.0, 2.0])
|
|
15
|
+
|
|
16
|
+
# Module-level proxy with auto backend detection
|
|
17
|
+
norm.cdf([0.0, 1.0, 2.0])
|
|
18
|
+
t.cdf(1.5, df=10, backend="cupy")
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import math
|
|
24
|
+
from abc import abstractmethod
|
|
25
|
+
from functools import lru_cache
|
|
26
|
+
from typing import Any, Protocol, runtime_checkable
|
|
27
|
+
|
|
28
|
+
import numpy as np
|
|
29
|
+
|
|
30
|
+
from statgpu.backends import _get_torch_device_str as _get_torch_device
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# =============================================================================
|
|
34
|
+
# SpecialFunctions protocol — abstracts away library-specific special functions
|
|
35
|
+
# =============================================================================
|
|
36
|
+
|
|
37
|
+
@runtime_checkable
class SpecialFunctions(Protocol):
    """Structural interface shared by the special-function providers.

    Concrete providers named by this module: ``CuPySpecialFunctions``,
    ``TorchSpecialFunctions`` and ``ScipySpecialFunctions``.  Because the
    class is ``runtime_checkable``, ``isinstance`` only verifies that the
    methods below exist; it does not check their signatures.
    """

    @abstractmethod
    def betainc(self, a, b, x):
        """Return the regularized incomplete beta function I_x(a, b)."""
        ...

    @abstractmethod
    def betaincinv(self, a, b, y):
        """Return x such that I_x(a, b) == y (inverse of ``betainc``)."""
        ...

    @abstractmethod
    def gammainc(self, a, x):
        """Return the regularized lower incomplete gamma function P(a, x)."""
        ...

    @abstractmethod
    def gammaincc(self, a, x):
        """Return the regularized upper incomplete gamma function Q(a, x)."""
        ...

    @abstractmethod
    def gammaincinv(self, a, q):
        """Return x such that P(a, x) == q (inverse of ``gammainc``)."""
        ...

    @abstractmethod
    def gammaln(self, x):
        """Return log(Gamma(x))."""
        ...

    @abstractmethod
    def erf(self, x):
        """Return the error function erf(x)."""
        ...

    @abstractmethod
    def erfc(self, x):
        """Return the complementary error function 1 - erf(x)."""
        ...

    @abstractmethod
    def erfcinv(self, y):
        """Return x such that erfc(x) == y."""
        ...
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# =============================================================================
|
|
83
|
+
# CuPy backend
|
|
84
|
+
# =============================================================================
|
|
85
|
+
|
|
86
|
+
class CuPySpecialFunctions:
    """Special functions via ``cupyx.scipy.special`` with LUT acceleration.

    Forward functions delegate directly to ``cupyx.scipy.special``.  The
    inverse functions (``betaincinv``, ``gammaincinv``) work in three steps:

    1. Look up a table of the *forward* function, built once per shape
       parameter on the CPU with SciPy and cached on the device.
    2. Interpolate that table linearly.
    3. Apply one Newton refinement step.

    This is intended to be much faster than the cupyx iterative solver.
    Shape parameters outside the range the tables are designed for fall back
    to ``cupyx``.

    Parameters
    ----------
    use_lut : bool, default True
        If False, the inverse functions always call ``cupyx`` directly.
    """

    # Right end of the gammaincinv table and upper clip of its results.
    _GAMMA_LUT_X_CAP = 1e6

    def __init__(self, *, use_lut: bool = True):
        import cupy as cp
        import cupyx.scipy.special as csp

        self._cp = cp
        self._csp = csp
        self.use_lut = use_lut
        # Per-instance LUT caches, keyed by the float shape parameter(s).
        # Each value is a ``(y_grid, x_grid)`` pair of device arrays, where
        # y_grid is the (non-decreasing) forward function evaluated at x_grid.
        self._betaincinv_lut = {}
        self._gammaincinv_lut = {}

    def betainc(self, a, b, x):
        return self._csp.betainc(a, b, self._cp.asarray(x, dtype=self._cp.float64))

    def _interp_inverse(self, y_grid, x_grid, y):
        """Linearly interpolate ``x`` with ``F(x) = y`` from a forward table.

        ``y_grid[i] = F(x_grid[i])`` must be non-decreasing.  The bracket is
        located with the probability clipped to ``[1e-15, 1 - 1e-15]``, but
        the interpolation weight uses the unclipped ``y``.
        """
        cp = self._cp
        y_safe = cp.clip(y, 1e-15, 1.0 - 1e-15)
        idx = cp.searchsorted(y_grid, y_safe).clip(1, len(y_grid) - 1)
        y0, y1 = y_grid[idx - 1], y_grid[idx]
        x0, x1 = x_grid[idx - 1], x_grid[idx]
        w = (y - y0) / (y1 - y0 + 1e-300)
        return x0 + w * (x1 - x0)

    def betaincinv(self, a, b, y):
        """Inverse regularized incomplete beta: x with I_x(a, b) = y.

        The LUT path is used only for scalar ``a, b`` in ``[0.3, 50]`` with
        ``|a - b| <= 30``; everything else goes to ``cupyx``.
        """
        cp = self._cp
        yt = cp.asarray(y, dtype=cp.float64)
        try:
            af, bf = float(a), float(b)
        except (TypeError, ValueError):
            return self._csp.betaincinv(a, b, yt)
        if (
            not self.use_lut
            or af < 0.3 or bf < 0.3
            or af > 50 or bf > 50
            or abs(af - bf) > 30
        ):
            return self._csp.betaincinv(a, b, yt)
        key = (af, bf)
        if key not in self._betaincinv_lut:
            x_grid, y_grid = self._build_betaincinv_lut(af, bf, 20000)
            # Store as (y, x): the lookup searches in y and interpolates x.
            self._betaincinv_lut[key] = (cp.asarray(y_grid), cp.asarray(x_grid))
        y_grid, x_grid = self._betaincinv_lut[key]
        x = cp.clip(self._interp_inverse(y_grid, x_grid, yt), 1e-10, 1.0 - 1e-10)
        # One Newton step on I_x(a, b) - y; the derivative is the beta density
        # x^(a-1) (1-x)^(b-1) / B(a, b), evaluated in log space for stability.
        log_beta = math.lgamma(af) + math.lgamma(bf) - math.lgamma(af + bf)
        diff = self._csp.betainc(af, bf, x) - yt
        log_deriv = (
            (af - 1.0) * cp.log(cp.clip(x, 1e-300, None))
            + (bf - 1.0) * cp.log(cp.clip(1.0 - x, 1e-300, None))
            - log_beta
        )
        deriv = cp.exp(log_deriv)
        x_new = x - diff / cp.clip(deriv, 1e-300, 1e300)
        return cp.clip(x_new, 1e-15, 1.0 - 1e-15)

    @staticmethod
    def _build_betaincinv_lut(a, b, n_grid):
        """Build LUT via scipy on CPU, returns (x_grid, y_grid) as numpy arrays.

        ``y_grid = betainc(a, b, x_grid)``.  40% of the points are log-spaced
        in ``[1e-15, 0.01]``, 40% mirror that towards 1, and the rest are
        linear in between.  The log spacing near both boundaries gives better
        precision when a or b is small (e.g. b=0.5 for t/f distributions).
        """
        import scipy.special as _scsp

        eps = 1e-15
        n_edge = int(n_grid * 0.4)
        n_mid = n_grid - 2 * n_edge
        x_lo = np.logspace(np.log10(eps), np.log10(0.01), n_edge)
        x_mid = np.linspace(0.01, 0.99, n_mid + 2)[1:-1]
        x_hi = 1.0 - np.logspace(np.log10(eps), np.log10(0.01), n_edge)[::-1]
        x_grid = np.concatenate([x_lo, x_mid, x_hi])
        if len(x_grid) > n_grid:
            x_grid = x_grid[:n_grid]
        y_grid = _scsp.betainc(a, b, x_grid)
        return x_grid, y_grid

    def gammainc(self, a, x):
        return self._csp.gammainc(
            self._cp.asarray(a, dtype=self._cp.float64),
            self._cp.asarray(x, dtype=self._cp.float64),
        )

    def gammaincc(self, a, x):
        return self._csp.gammaincc(
            self._cp.asarray(a, dtype=self._cp.float64),
            self._cp.asarray(x, dtype=self._cp.float64),
        )

    def gammaincinv(self, a, q):
        """Inverse regularized lower incomplete gamma: x with P(a, x) = q.

        The LUT path is used only for scalar ``a >= 1`` whose table fits
        below ``_GAMMA_LUT_X_CAP``.  Larger ``a`` would have the table and the
        result truncated at the cap, so it goes to ``cupyx`` instead.
        """
        cp = self._cp
        qt = cp.asarray(q, dtype=cp.float64)
        try:
            af = float(a)
        except (TypeError, ValueError):
            return self._csp.gammaincinv(cp.asarray(a, dtype=cp.float64), qt)
        if (
            not self.use_lut
            or af < 1.0
            or self._gammaincinv_x_upper(af) > self._GAMMA_LUT_X_CAP
        ):
            return self._csp.gammaincinv(cp.asarray(a, dtype=cp.float64), qt)
        key = (af,)
        if key not in self._gammaincinv_lut:
            x_grid, y_grid = self._build_gammaincinv_lut(af, 20000)
            # Store as (y, x): the lookup searches in y and interpolates x.
            self._gammaincinv_lut[key] = (cp.asarray(y_grid), cp.asarray(x_grid))
        y_grid, x_grid = self._gammaincinv_lut[key]
        x = cp.clip(self._interp_inverse(y_grid, x_grid, qt), 1e-15, self._GAMMA_LUT_X_CAP)
        # One Newton step on P(a, x) - q; the derivative is the gamma density
        # x^(a-1) e^(-x) / Gamma(a), evaluated in log space.
        diff = self._csp.gammainc(af, x) - qt
        log_deriv = (af - 1.0) * cp.log(cp.clip(x, 1e-300, None)) - x - math.lgamma(af)
        deriv = cp.exp(log_deriv)
        x_new = x - diff / cp.clip(deriv, 1e-300, 1e300)
        return cp.clip(x_new, 1e-15, self._GAMMA_LUT_X_CAP)

    @staticmethod
    def _gammaincinv_x_upper(a):
        """Uncapped right end of the gammaincinv table, ``a + 20*sqrt(a) + 10``.

        For ``a >= 1`` this is 20 standard deviations above the mean of
        Gamma(a, 1), so the table covers essentially all probability mass.
        """
        return a + 20 * math.sqrt(max(a, 0.1)) + 10

    @staticmethod
    def _build_gammaincinv_lut(a, n_grid):
        """Build LUT via scipy on CPU, returns (x_grid, y_grid) as numpy arrays.

        ``y_grid = gammainc(a, x_grid)``, with the end values pinned to
        exactly 0 and 1.  The grid is split at ``x_split = x_max / 100``:

        * one third of the points are log-spaced on ``[1e-15, x_split)`` to
          resolve the lower tail;
        * the rest are linear on ``[x_split, x_max]`` to resolve the bulk of
          the distribution.
        """
        import scipy.special as _scsp

        x_max = min(
            CuPySpecialFunctions._gammaincinv_x_upper(a),
            CuPySpecialFunctions._GAMMA_LUT_X_CAP,
        )
        n_log = n_grid // 3
        n_lin = n_grid - n_log
        # Stop the log segment well short of x_max. Otherwise its last interval
        # spans the whole bulk and the linear points all land inside it.
        x_split = 0.01 * x_max
        x_lo = np.logspace(-15, math.log10(x_split), n_log, endpoint=False)
        x_hi = np.linspace(x_split, x_max, n_lin)
        x_grid = np.concatenate([x_lo, x_hi])
        y_grid = _scsp.gammainc(a, x_grid)
        y_grid[0] = 0.0
        y_grid[-1] = 1.0
        return x_grid, y_grid

    def gammaln(self, x):
        return self._csp.gammaln(self._cp.asarray(x, dtype=self._cp.float64))

    def erf(self, x):
        return self._csp.erf(self._cp.asarray(x, dtype=self._cp.float64))

    def erfc(self, x):
        return self._csp.erfc(self._cp.asarray(x, dtype=self._cp.float64))

    def erfcinv(self, y):
        return self._csp.erfcinv(self._cp.asarray(y, dtype=self._cp.float64))

    # Array helpers below are not part of the SpecialFunctions protocol.

    def sqrt(self, x):
        return self._cp.sqrt(self._cp.asarray(x, dtype=self._cp.float64))

    @property
    def pi(self):
        return self._cp.pi

    def clip(self, x, lo, hi):
        return self._cp.clip(x, lo, hi)

    def where(self, cond, x, y):
        return self._cp.where(cond, x, y)

    def as_float64(self, x):
        return self._cp.asarray(x, dtype=self._cp.float64)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
# =============================================================================
|
|
252
|
+
# Torch backend
|
|
253
|
+
# =============================================================================
|
|
254
|
+
|
|
255
|
+
# Module-level cache for torch betaincinv inverse LUTs (scalar a, b).
# Key: (a, b, device, n_points) -> (y_grid, x_grid) float64 tensors on device
_torch_betaincinv_lut_cache: dict = {}


# Module-level cache for torch betainc forward LUTs (scalar a, b).
# Key: (a, b, device, n_points) -> (x_grid, y_grid) float64 tensors on device
_torch_betainc_lut_cache: dict = {}


def _unit_interval_edge_grid(n_points, eps=1e-15):
    """Return ``n_points`` non-decreasing values spanning ``[eps, 1 - eps]``.

    40% of the points are log-spaced on ``[eps, 0.01]`` and 40% mirror that
    towards 1. The remaining points are linear, strictly inside
    ``(0.01, 0.99)``. The log edges resolve the behaviour of the beta
    functions near 0 and 1.
    """
    n_edge = int(n_points * 0.4)
    n_mid = n_points - 2 * n_edge
    lo = np.logspace(np.log10(eps), np.log10(0.01), n_edge)
    mid = np.linspace(0.01, 0.99, n_mid + 2)[1:-1]
    hi = 1.0 - np.logspace(np.log10(eps), np.log10(0.01), n_edge)[::-1]
    return np.concatenate([lo, mid, hi])[:n_points]


def _get_torch_betaincinv_lut(a, b, device, n_points=20000):
    """Return a cached inverse LUT ``(y_grid, x_grid)`` for betaincinv(a, b, .).

    The table has one float64 tensor pair on ``device``:

    * ``y_grid`` is a probability grid;
    * ``x_grid = scipy.special.betaincinv(a, b, y_grid)``, computed once on
      the CPU.

    Callers bracket a probability with ``searchsorted`` on ``y_grid`` and
    interpolate ``x_grid``. The probability grid is log-spaced towards 0 and
    1, so small tail probabilities (e.g. p-values) fall between nearby
    entries instead of inside one ``1 / n_points``-wide interval. The cache
    key includes ``n_points``, so tables of different sizes do not collide.
    """
    from scipy import special as _scsp
    import torch

    cache_key = (a, b, device, n_points)
    cached = _torch_betaincinv_lut_cache.get(cache_key)
    if cached is not None:
        return cached

    y_vals = _unit_interval_edge_grid(n_points)
    x_vals = _scsp.betaincinv(a, b, y_vals)
    y_grid = torch.as_tensor(y_vals, dtype=torch.float64, device=device)
    x_grid = torch.as_tensor(x_vals, dtype=torch.float64, device=device)
    _torch_betaincinv_lut_cache[cache_key] = (y_grid, x_grid)
    return y_grid, x_grid


def _get_torch_betainc_lut(a, b, device, n_points=40000):
    """Return a cached forward LUT ``(x_grid, y_grid)`` for betainc(a, b, .).

    ``y_grid = scipy.special.betainc(a, b, x_grid)`` is computed once on the
    CPU and moved to ``device`` as float64 tensors. Callers locate ``x``
    with ``searchsorted`` on ``x_grid`` and interpolate ``y_grid``. The
    x-grid depends only on ``n_points``, not on ``a`` or ``b``, so all
    tables of the same size share it. It is log-spaced near 0 and 1 for
    precision when a or b is small.
    """
    from scipy import special as _scsp
    import torch

    cache_key = (a, b, device, n_points)
    cached = _torch_betainc_lut_cache.get(cache_key)
    if cached is not None:
        return cached

    x_vals = _unit_interval_edge_grid(n_points)
    y_vals = _scsp.betainc(a, b, x_vals)
    x_grid = torch.as_tensor(x_vals, dtype=torch.float64, device=device)
    y_grid = torch.as_tensor(y_vals, dtype=torch.float64, device=device)
    _torch_betainc_lut_cache[cache_key] = (x_grid, y_grid)
    return x_grid, y_grid
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
class TorchSpecialFunctions:
|
|
317
|
+
"""Special functions via torch.special with fallbacks for missing functions."""
|
|
318
|
+
|
|
319
|
+
def __init__(self, device: str | None = None, *, use_lut: bool = True):
|
|
320
|
+
import torch
|
|
321
|
+
self._torch = torch
|
|
322
|
+
self._device = device or _get_torch_device()
|
|
323
|
+
self.use_lut = use_lut
|
|
324
|
+
|
|
325
|
+
def _as_tensor(self, x):
|
|
326
|
+
return self._torch.as_tensor(x, dtype=self._torch.float64, device=self._device)
|
|
327
|
+
|
|
328
|
+
# ── betainc fallback ───────────────────────────────────────────
|
|
329
|
+
def betainc(self, a, b, x):
|
|
330
|
+
t = self._torch
|
|
331
|
+
# Check if torch has native betainc (>= 1.8)
|
|
332
|
+
if hasattr(t.special, "betainc"):
|
|
333
|
+
return t.special.betainc(
|
|
334
|
+
self._as_tensor(a), self._as_tensor(b), self._as_tensor(x),
|
|
335
|
+
)
|
|
336
|
+
# LUT-based betainc for scalar a, b (major speedup for binom)
|
|
337
|
+
try:
|
|
338
|
+
af, bf = float(a), float(b)
|
|
339
|
+
except (TypeError, ValueError):
|
|
340
|
+
pass # fall through to element-wise loop below
|
|
341
|
+
else:
|
|
342
|
+
if self.use_lut:
|
|
343
|
+
try:
|
|
344
|
+
xg, yg = _get_torch_betainc_lut(af, bf, self._device)
|
|
345
|
+
xt = self._as_tensor(x)
|
|
346
|
+
xt_clamp = t.clamp(xt, 0.0, 1.0)
|
|
347
|
+
idx = t.searchsorted(xg, xt_clamp).clip(1, len(xg) - 1)
|
|
348
|
+
x0, x1 = xg[idx - 1], xg[idx]
|
|
349
|
+
y0, y1 = yg[idx - 1], yg[idx]
|
|
350
|
+
w = (xt_clamp - x0) / (x1 - x0 + 1e-300)
|
|
351
|
+
return (y0 + w * (y1 - y0)).clamp(0.0, 1.0).view_as(xt)
|
|
352
|
+
except Exception:
|
|
353
|
+
pass # fall through to integral fallback
|
|
354
|
+
return self._betainc_integral(af, bf, self._as_tensor(x))
|
|
355
|
+
# Non-scalar a or b — grouped LUT lookup (avoids element-wise Chebyshev integral)
|
|
356
|
+
try:
|
|
357
|
+
return self._betainc_batch(a, b, x)
|
|
358
|
+
except Exception:
|
|
359
|
+
# Full fallback: compute on CPU via scipy
|
|
360
|
+
try:
|
|
361
|
+
import scipy.special as _scsp
|
|
362
|
+
a_np = np.asarray(self._as_tensor(a).cpu().numpy())
|
|
363
|
+
b_np = np.asarray(self._as_tensor(b).cpu().numpy())
|
|
364
|
+
x_np = np.asarray(self._as_tensor(x).cpu().numpy())
|
|
365
|
+
result = _scsp.betainc(
|
|
366
|
+
np.clip(a_np, 1, None).astype(int),
|
|
367
|
+
np.clip(b_np, 1, None).astype(int),
|
|
368
|
+
np.clip(x_np, 0.0, 1.0),
|
|
369
|
+
)
|
|
370
|
+
return self._as_tensor(result)
|
|
371
|
+
except Exception:
|
|
372
|
+
return self._betainc_integral(1, 1, self._as_tensor(x))
|
|
373
|
+
|
|
374
|
+
def _betainc_integral(self, a, b, x):
|
|
375
|
+
"""Regularized incomplete beta via trapezoidal rule on Chebyshev-mapped grid.
|
|
376
|
+
|
|
377
|
+
Uses Chebyshev-node mapping to cluster grid points near s=0 and s=1.
|
|
378
|
+
"""
|
|
379
|
+
import math as _math
|
|
380
|
+
t = self._torch
|
|
381
|
+
device = x.device
|
|
382
|
+
x = t.clamp(x, 0.0, 1.0)
|
|
383
|
+
af, bf = float(a), float(b)
|
|
384
|
+
if af < 1.0 or bf < 1.0:
|
|
385
|
+
n_grid = 64000
|
|
386
|
+
elif af < 5.0 or bf < 5.0:
|
|
387
|
+
n_grid = 16000
|
|
388
|
+
else:
|
|
389
|
+
n_grid = 8000
|
|
390
|
+
theta = t.linspace(0, _math.pi, n_grid, device=device, dtype=t.float64)
|
|
391
|
+
s = 0.5 * (1.0 + t.cos(theta)) # descending [≈1, 0]
|
|
392
|
+
s = s.flip(0) # ascending [0, ≈1]
|
|
393
|
+
eps = 1e-14
|
|
394
|
+
log_val = (a - 1) * t.log(s + 1e-300) + (b - 1) * t.log1p(-s + 1e-300)
|
|
395
|
+
log_val = t.where(t.isfinite(log_val), log_val, t.tensor(-700.0, dtype=t.float64, device=device))
|
|
396
|
+
f = t.exp(log_val)
|
|
397
|
+
beta_ab = _math.exp(_math.lgamma(af) + _math.lgamma(bf) - _math.lgamma(af + bf))
|
|
398
|
+
ds = s[1:] - s[:-1]
|
|
399
|
+
cum = t.zeros(n_grid, device=device, dtype=t.float64)
|
|
400
|
+
cum[1:] = t.cumsum((f[:-1] + f[1:]) * 0.5 * ds, dim=0)
|
|
401
|
+
x_flat = x.flatten()
|
|
402
|
+
idx = t.searchsorted(s, x_flat, right=True).clamp(1, n_grid - 1)
|
|
403
|
+
frac = (x_flat - s[idx - 1]) / (s[idx] - s[idx - 1] + 1e-300)
|
|
404
|
+
frac = frac.clamp(0.0, 1.0)
|
|
405
|
+
result = cum[idx - 1] + frac * (cum[idx] - cum[idx - 1])
|
|
406
|
+
result = result / beta_ab
|
|
407
|
+
result = t.clamp(result, 0.0, 1.0)
|
|
408
|
+
result = t.where(x_flat <= eps, 0.0, result)
|
|
409
|
+
result = t.where(x_flat >= 1 - eps, 1.0, result)
|
|
410
|
+
return result.view_as(x)
|
|
411
|
+
|
|
412
|
+
def _betainc_batch(self, a, b, x):
    """Batched betainc for array-valued, integer shape parameters via LUTs.

    ``a``, ``b`` and ``x`` are broadcast together. Shape parameters are
    rounded to integers and clamped to ``[1, 100000]`` (edge cases are
    expected to be overwritten by the caller). The result is therefore only
    exact for integer-valued parameters, such as those in binomial CDFs.

    All forward LUTs from ``_get_torch_betainc_lut`` share one x-grid, so:

    1. Build a ``(n_pairs, n_grid)`` table with one y-grid row per unique
       ``(a, b)`` pair.
    2. Locate every ``x`` in the shared x-grid with a single searchsorted.
    3. Gather each element's two bracketing y values from its own row and
       interpolate. This costs O(n_elem) work and memory; no
       ``(n_pairs, n_elem)`` intermediate is materialized.

    Pairs whose LUT cannot be built are evaluated with ``_betainc_integral``.
    If no LUT can be built at all, ``_betainc_integral(1, 1, x)`` is returned
    flattened. NOTE(review): that is ``clip(x, 0, 1)``, not I_x(a, b).

    Returns
    -------
    torch.Tensor
        Values with the broadcast shape of ``a``, ``b`` and ``x``.
    """
    t = self._torch
    a_t, b_t, x_t = t.broadcast_tensors(
        self._as_tensor(a), self._as_tensor(b), self._as_tensor(x),
    )
    out_shape = x_t.shape
    x_flat = x_t.reshape(-1)

    # Encode each integer pair as one int64 key. The base must exceed the
    # largest clamped value so that b = 100000 cannot alias (a + 1, 0).
    key_base = 100001
    ai = t.clamp(t.round(a_t.reshape(-1)).long(), 1, 100000)
    bi = t.clamp(t.round(b_t.reshape(-1)).long(), 1, 100000)
    unique_keys, inverse_idx = t.unique(ai * key_base + bi, return_inverse=True)

    xg = None
    rows = []
    failed_pairs = []
    for pair_idx, key in enumerate(unique_keys.tolist()):
        a_val, b_val = divmod(key, key_base)
        try:
            xg_i, yg_i = _get_torch_betainc_lut(a_val, b_val, self._device)
        except Exception:
            failed_pairs.append((pair_idx, float(a_val), float(b_val)))
            rows.append(None)
            continue
        if xg is None:
            xg = xg_i  # every LUT is built on the same x-grid
        rows.append(yg_i)

    if xg is None:
        # All LUTs failed, fall back
        return self._betainc_integral(1, 1, x_flat)

    # Failed pairs get a placeholder row; their elements are recomputed below.
    placeholder = t.zeros_like(xg)
    y_grid = t.stack([placeholder if row is None else row for row in rows])

    # Single searchsorted for all elements; weights are shared by all pairs.
    n_grid = xg.numel()
    x_clamp = t.clamp(x_flat, 0.0, 1.0)
    sidx = t.searchsorted(xg, x_clamp).clamp(1, n_grid - 1)
    w = (x_clamp - xg[sidx - 1]) / (xg[sidx] - xg[sidx - 1] + 1e-300)

    # Gather only each element's own row: shapes are (n_elem,).
    y0 = y_grid[inverse_idx, sidx - 1]
    y1 = y_grid[inverse_idx, sidx]
    result = (y0 + w * (y1 - y0)).clamp(0.0, 1.0)

    for pair_idx, a_val, b_val in failed_pairs:
        mask = inverse_idx == pair_idx
        if t.any(mask):
            result[mask] = self._betainc_integral(a_val, b_val, x_clamp[mask])

    return result.view(out_shape)
|
|
489
|
+
|
|
490
|
+
def betaincinv(self, a, b, y):
    """Inverse regularized incomplete beta for scalar ``a``, ``b``.

    The method tries, in order:

    1. ``torch.special.betaincinv`` if this torch build provides it.
    2. With ``use_lut``: look up the cached inverse LUT, interpolate it
       linearly, and take one Newton step on ``_betainc_integral(a, b, x) - y``.
    3. ``_betaincinv_newton`` when the LUT is disabled or any LUT step
       raises.

    LUT results are clamped to ``[1e-10, 1 - 1e-10]``.
    """
    th = self._torch
    a_val, b_val = float(a), float(b)
    y_t = self._as_tensor(y)

    if hasattr(th.special, "betaincinv"):
        return th.special.betaincinv(
            self._as_tensor(a), self._as_tensor(b), y_t,
        )

    if not self.use_lut:
        return self._betaincinv_newton(a_val, b_val, y_t)

    try:
        y_lut, x_lut = _get_torch_betaincinv_lut(a_val, b_val, y_t.device)
        last = len(y_lut) - 1
        pos = th.searchsorted(y_lut, th.clamp(y_t, 0.0, 1.0)).clamp(1, last)

        lo_y = y_lut[pos - 1]
        hi_y = y_lut[pos]
        lo_x = x_lut[pos - 1]
        hi_x = x_lut[pos]
        frac = (y_t - lo_y) / (hi_y - lo_y + 1e-300)
        guess = th.clamp(lo_x + frac * (hi_x - lo_x), 1e-10, 1.0 - 1e-10)

        # Single Newton correction using the beta density as the derivative.
        beta_ab = math.exp(
            math.lgamma(a_val) + math.lgamma(b_val) - math.lgamma(a_val + b_val)
        )
        current = self._betainc_integral(a_val, b_val, guess)
        slope = (
            th.pow(th.clamp(guess, 1e-300, 1 - 1e-300), a_val - 1)
            * th.pow(th.clamp(1 - guess, 1e-300, 1 - 1e-300), b_val - 1)
            / beta_ab
        )
        slope = th.clamp(slope, 1e-300, 1e300)
        refined = guess - (current - y_t) / slope
        return th.clamp(refined, 1e-10, 1.0 - 1e-10)
    except Exception:
        return self._betaincinv_newton(a_val, b_val, y_t)
|
|
524
|
+
|
|
525
|
+
def _betaincinv_newton(self, a, b, y):
    """Inverse regularized incomplete beta via damped Newton-Raphson.

    Used when no LUT is available. The method works in three steps:

    1. Start from a logit-normal approximation whose mean and variance come
       from digamma(a) - digamma(b) and 1/a + 1/b.
    2. Run up to 50 Newton steps on ``_betainc_integral(a, b, x) - y``, with
       the beta density as the derivative.
    3. Inside each step, halve the whole step vector (up to 20 times)
       whenever any element would leave ``[1e-15, 1 - 1e-15]``.

    Parameters
    ----------
    a, b : float
        Scalar shape parameters (passed to ``math.lgamma`` / SciPy digamma).
    y : torch.Tensor
        Target probabilities; clamped to ``[1e-15, 1 - 1e-15]``.

    Returns
    -------
    torch.Tensor
        ``x`` clamped to ``[1e-10, 1 - 1e-10]``.
    """
    t = self._torch
    # NOTE(review): ``device`` is unused below.
    device = y.device
    y = t.clamp(y, 1e-15, 1 - 1e-15)
    import math as _math
    # Logit-normal approximation for initial guess
    import scipy.special as _scsp
    mu = _scsp.digamma(a) - _scsp.digamma(b)
    sigma2 = 1.0 / a + 1.0 / b
    sigma = math.sqrt(sigma2)
    # Standard-normal quantile of y via -sqrt(2) * erfcinv(2y). ``erfcinv``
    # is defined on this class outside the visible chunk; its return type is
    # not established here, hence the conversion guard on the next line.
    z = -_math.sqrt(2.0) * self.erfcinv(2.0 * y)
    z = self._as_tensor(z) if not isinstance(z, t.Tensor) else z
    logit_q = mu + sigma * z
    x = 1.0 / (1.0 + t.exp(-logit_q))
    x = t.clamp(x, 1e-10, 1.0 - 1e-10)
    # Damped Newton refinement
    beta_ab = math.exp(math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b))
    for _ in range(50):
        val = self._betainc_integral(a, b, x)
        diff = val - y
        # NOTE(review): 1e-13 may be tighter than the quadrature accuracy
        # of _betainc_integral, in which case all 50 iterations run.
        if t.max(t.abs(diff)) < 1e-13:
            break
        # Beta density x^(a-1) (1-x)^(b-1) / B(a, b) = d/dx I_x(a, b).
        deriv = t.pow(t.clamp(x, 1e-300, 1 - 1e-300), a - 1) * \
            t.pow(t.clamp(1 - x, 1e-300, 1 - 1e-300), b - 1) / beta_ab
        deriv = t.clamp(deriv, 1e-300, 1e300)
        step = diff / deriv

        # Damped: backtracking to keep x in valid range. The check is global,
        # so one out-of-range element halves the step for every element.
        for _ in range(20):
            x_new = x - step
            if t.min(x_new) < 1e-15 or t.max(x_new) > 1.0 - 1e-15:
                step = step * 0.5
            else:
                break

        x = x - step
        x = t.clamp(x, 1e-10, 1.0 - 1e-10)
    return x
|
|
564
|
+
|
|
565
|
+
def gammainc(self, a, x):
    """Regularized lower incomplete gamma ``P(a, x)`` (torch.special)."""
    special = self._torch.special
    a_t = self._as_tensor(a)
    x_t = self._as_tensor(x)
    return special.gammainc(a_t, x_t)
|
|
567
|
+
|
|
568
|
+
def gammaincc(self, a, x):
    """Regularized upper incomplete gamma ``Q(a, x)`` (torch.special)."""
    special = self._torch.special
    a_t = self._as_tensor(a)
    x_t = self._as_tensor(x)
    return special.gammaincc(a_t, x_t)
|
|
570
|
+
|
|
571
|
+
def gammaincinv(self, a, q):
    """Inverse of ``P(a, x)`` in ``x``.

    Uses ``torch.special.gammaincinv`` if this torch build provides it,
    otherwise the damped Newton solver ``self._gammaincinv_newton``
    (``a`` must then be convertible to ``float``).
    """
    torch_mod = self._torch
    shape = float(a)
    q_t = self._as_tensor(q)
    native = getattr(torch_mod.special, "gammaincinv", None)
    if native is None:
        return self._gammaincinv_newton(shape, q_t)
    return native(self._as_tensor(a), q_t)
|
|
578
|
+
|
|
579
|
+
def _gammaincinv_newton(self, a, q):
    """Inverse regularized lower incomplete gamma via damped Newton-Raphson.

    Parameters
    ----------
    a : float
        Shape parameter (> 0), a Python scalar.
    q : torch.Tensor
        Target probabilities, clamped to ``[1e-15, 1 - 1e-15]``.

    Returns
    -------
    torch.Tensor
        ``x`` with ``torch.special.gammainc(a, x) ~= q``, kept inside
        ``[1e-300, x_cap]``. ``x_cap = max(1e6, a + 50*sqrt(a) + 100)``
        keeps the historical 1e6 limit but grows with ``a`` so large-shape
        quantiles are not truncated.

    Notes
    -----
    * Initial guess: Wilson-Hilferty, or the lower-tail series
      ``x ~ (q * Gamma(a + 1))**(1/a)`` where Wilson-Hilferty goes
      non-positive.
    * Runs at most 50 Newton iterations and stops once every element
      satisfies ``|P(a, x) - q| <= 1e-13 * q``.
    * Steps that would leave ``[1e-300, 2 * x_cap]`` are halved
      element-wise, up to 20 times.
    """
    t = self._torch
    q = t.clamp(q, 1e-15, 1 - 1e-15)
    if q.numel() == 0:
        return q
    at = t.tensor(a, dtype=t.float64, device=q.device)
    x_min = 1e-300
    x_cap = max(1e6, a + 50.0 * math.sqrt(a) + 100.0)

    # Wilson-Hilferty: (X/a)^(1/3) ~ N(1 - 1/(9a), 1/(9a)) for X ~ Gamma(a, 1).
    z = math.sqrt(2.0) * t.erfinv(2.0 * q - 1.0)
    u = z / math.sqrt(9.0 * a) + (1.0 - 1.0 / (9.0 * a))
    # Where W-H is non-positive (far lower tail, small a) use P(a, x) ~ x^a / Gamma(a + 1).
    x_tail = t.exp((t.log(q) + math.lgamma(a + 1.0)) / a)
    x = t.where(u > 0.0, a * t.pow(u, 3.0), x_tail)
    x = t.clamp(x, x_min, x_cap)

    lg_a = math.lgamma(a)
    for _ in range(50):
        diff = t.special.gammainc(at, x) - q
        if bool(t.all(t.abs(diff) <= 1e-13 * q)):
            break
        # dP/dx = x^(a-1) e^(-x) / Gamma(a), evaluated in log space.
        log_deriv = (a - 1.0) * t.log(x) - x - lg_a
        step = diff / t.clamp(t.exp(log_deriv), 1e-300, 1e300)

        # Element-wise backtracking: halve only the steps that leave the interval.
        for _ in range(20):
            x_new = x - step
            bad = (x_new < x_min) | (x_new > 2.0 * x_cap)
            if not bool(t.any(bad)):
                break
            step = t.where(bad, 0.5 * step, step)

        x = t.clamp(x - step, x_min, x_cap)
    return x
|
|
622
|
+
|
|
623
|
+
def gammaln(self, x):
    """Natural log of |Gamma(x)| via ``torch.lgamma``."""
    tensor = self._as_tensor(x)
    return self._torch.lgamma(tensor)
|
|
625
|
+
|
|
626
|
+
def erf(self, x):
    """Error function via ``torch.erf``."""
    tensor = self._as_tensor(x)
    return self._torch.erf(tensor)
|
|
628
|
+
|
|
629
|
+
def erfc(self, x):
    """Complementary error function via ``torch.erfc``."""
    tensor = self._as_tensor(x)
    return self._torch.erfc(tensor)
|
|
631
|
+
|
|
632
|
+
def erfcinv(self, y):
    """Inverse complementary error function.

    Uses ``torch.special.erfcinv`` when available. Otherwise it evaluates
    ``erfinv(1 - y)``, which is exact in form but discards the digits of a
    small ``y`` and rounds to ``1 -> +inf`` once ``y < ~1.1e-16``.

    For ``0 < y < 1e-6`` (when ``torch.special.erfcx`` exists) the result is
    instead obtained by Newton's method on
    ``g(x) = log(erfc(x)) - log(y)``, written with ``erfcx`` to avoid
    underflow:

    * ``log(erfc(x)) = log(erfcx(x)) - x**2``
    * ``g'(x) = -2 / (sqrt(pi) * erfcx(x))``

    ``g`` is concave (erfc is log-concave), and ``x0 = sqrt(-log(y))`` lies
    right of the root because ``erfc(x) <= exp(-x**2)``, so the iteration
    converges monotonically.
    """
    t = self._torch
    yt = self._as_tensor(y)
    if hasattr(t.special, "erfcinv"):
        return t.special.erfcinv(yt)
    # Fallback: erfcinv(y) = erfinv(1 - y)
    out = t.erfinv(1.0 - yt)
    erfcx = getattr(t.special, "erfcx", None)
    tail = (yt > 0.0) & (yt < 1e-6)
    if erfcx is None or not bool(t.any(tail)):
        return out
    # Placeholder 1e-10 off the tail keeps log() finite; those lanes are discarded.
    log_y = t.log(t.where(tail, yt, t.full_like(yt, 1e-10)))
    x = t.sqrt(-log_y)
    half_sqrt_pi = 0.5 * math.sqrt(math.pi)
    for _ in range(10):
        ex = erfcx(x)
        x = x + (t.log(ex) - x * x - log_y) * half_sqrt_pi * ex
    return t.where(tail, x, out)
|
|
639
|
+
|
|
640
|
+
def sqrt(self, x):
    """Elementwise square root via ``torch.sqrt``."""
    tensor = self._as_tensor(x)
    return self._torch.sqrt(tensor)
|
|
642
|
+
|
|
643
|
+
@property
def pi(self):
    """pi as a 0-d float64 tensor on this backend's device (``self._device``)."""
    torch_mod = self._torch
    return torch_mod.tensor(math.pi, dtype=torch_mod.float64, device=self._device)
|
|
646
|
+
|
|
647
|
+
def clip(self, x, lo, hi):
    """Clamp ``x`` into ``[lo, hi]`` with ``torch.clamp``."""
    clamp = self._torch.clamp
    return clamp(x, lo, hi)
|
|
649
|
+
|
|
650
|
+
def where(self, cond, x, y):
    """``torch.where`` that also accepts a non-boolean tensor condition.

    A tensor ``cond`` of any other dtype is cast to ``bool`` first, since
    torch.where requires a boolean mask; other ``cond`` values pass through.
    """
    torch_mod = self._torch
    needs_cast = isinstance(cond, torch_mod.Tensor) and cond.dtype != torch_mod.bool
    mask = cond.to(dtype=torch_mod.bool) if needs_cast else cond
    return torch_mod.where(mask, x, y)
|
|
656
|
+
|
|
657
|
+
def as_float64(self, x):
    """Convert ``x`` with this backend's ``_as_tensor`` helper."""
    convert = self._as_tensor
    return convert(x)
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
# =============================================================================
|
|
662
|
+
# SciPy / NumPy backend
|
|
663
|
+
# =============================================================================
|
|
664
|
+
|
|
665
|
+
class ScipySpecialFunctions:
    """Special functions via scipy.special (pure NumPy / CPU).

    Inverse functions (gammaincinv, betaincinv) use a cached LUT + linear
    interpolation followed by one Newton step. This gives ~100ms evaluation
    on 1M points (vs ~3000ms for scipy's iterative solver). The raw
    interpolation is accurate to ~1e-5 for typical parameter ranges, and the
    Newton step refines it. Edge-case parameters (extreme a, b) fall back to
    scipy for full accuracy.

    On the LUT path the inverses match scipy at the boundaries:

    * probability 0 maps to 0;
    * probability 1 maps to the upper end of the support (1 for betaincinv,
      inf for gammaincinv);
    * probabilities outside [0, 1] map to NaN.

    The class also provides thin NumPy wrappers for the elementwise helpers
    (exp, log, log1p, square, abs, minimum, maximum, any, atan, tan) that
    the distribution classes call on their ``sf`` object.
    """

    def __init__(self, *, use_lut: bool = True):
        import scipy.special as scsp

        self._scsp = scsp
        self.use_lut = use_lut
        # Per-instance LUT caches for the inverse functions, keyed by the
        # scalar shape parameter(s): (a,) for gammaincinv, (a, b) for betaincinv.
        self._gammaincinv_lut = {}
        self._betaincinv_lut = {}

    @staticmethod
    @lru_cache(maxsize=256)
    def _make_gammaincinv_lut(a, n_grid):
        """Build LUT: x_grid -> y = gammainc(a, x_grid).

        Grid layout (exactly ``n_grid`` strictly increasing points):

        * one third log-spaced on ``[1e-15, 1)``, to resolve the lower tail
          where ``P(a, x) ~ x**a``;
        * the rest linearly spaced on ``[1, x_max]``, covering the bulk and
          the upper tail.

        ``x_max = a + 20*sqrt(max(a, 0.1)) + 10``, capped at 1e6. The end
        values ``y_grid[0]`` and ``y_grid[-1]`` are pinned to exactly 0 and 1.
        (Previously the log segment ran all the way to ``x_max``, leaving two
        thirds of the points in the flat upper tail.)
        """
        import scipy.special as _scsp

        x_max = min(a + 20 * math.sqrt(max(a, 0.1)) + 10, 1e6)
        x_split = 1.0
        n_log = n_grid // 3
        n_lin = n_grid - n_log
        x_lo = np.logspace(-15.0, math.log10(x_split), n_log, endpoint=False)
        x_hi = np.linspace(x_split, x_max, n_lin)
        x_grid = np.concatenate([x_lo, x_hi])
        y_grid = _scsp.gammainc(a, x_grid)
        y_grid[0] = 0.0
        y_grid[-1] = 1.0
        return x_grid, y_grid

    @staticmethod
    @lru_cache(maxsize=256)
    def _make_betaincinv_lut(a, b, n_grid):
        """Build LUT: x_grid -> y = betainc(a, b, x_grid).

        Uses log spacing near both boundaries for better precision when
        a or b is small (e.g. b=0.5 for t/f distributions).
        """
        import scipy.special as _scsp

        eps = 1e-15
        # Log spacing: 40% near each boundary, 20% in the middle
        n_edge = int(n_grid * 0.4)
        n_mid = n_grid - 2 * n_edge
        x_lo = np.logspace(np.log10(eps), np.log10(0.01), n_edge)
        x_mid = np.linspace(0.01, 0.99, n_mid + 2)[1:-1]
        x_hi = 1.0 - np.logspace(np.log10(eps), np.log10(0.01), n_edge)[::-1]
        x_grid = np.concatenate([x_lo, x_mid, x_hi])
        if len(x_grid) > n_grid:
            x_grid = x_grid[:n_grid]
        y_grid = _scsp.betainc(a, b, x_grid)
        return x_grid, y_grid

    @staticmethod
    def _inverse_lut(q_or_y, x_grid, y_grid):
        """Use LUT for inverse: given q, find x such that f(x) = q.

        Linear interpolation between the bracketing table entries; targets
        outside the table's y-range snap to its end points.
        """
        idx = np.searchsorted(y_grid, q_or_y, side='left').clip(1, len(y_grid) - 1)
        frac = (q_or_y - y_grid[idx - 1]) / (y_grid[idx] - y_grid[idx - 1] + 1e-300)
        frac = np.clip(frac, 0.0, 1.0)
        return x_grid[idx - 1] + frac * (x_grid[idx] - x_grid[idx - 1])

    @staticmethod
    def _apply_endpoints(x, p, at_one):
        """Pin boundary probabilities the way scipy.special does.

        Sets ``p == 0 -> 0``, ``p == 1 -> at_one`` and ``p`` outside [0, 1]
        to NaN. A 0-d result is returned as a NumPy scalar.
        """
        x = np.where(p == 0.0, 0.0, x)
        x = np.where(p == 1.0, at_one, x)
        x = np.where((p < 0.0) | (p > 1.0), np.nan, x)
        return x[()]

    def betainc(self, a, b, x):
        return self._scsp.betainc(a, b, np.asarray(x, dtype=np.float64))

    def betaincinv(self, a, b, y):
        """Inverse of ``betainc(a, b, .)``; LUT + one Newton step for scalar a, b."""
        arr = np.asarray(y, dtype=np.float64)
        try:
            af, bf = float(a), float(b)
        except (TypeError, ValueError):
            return self._scsp.betaincinv(a, b, arr)
        if not self.use_lut:
            return self._scsp.betaincinv(af, bf, arr)
        if af < 0.3 or bf < 0.3 or af > 50 or bf > 50 or abs(af - bf) > 30:
            return self._scsp.betaincinv(af, bf, arr)
        # LUT + 1-step Newton refinement
        key = (af, bf)
        if key not in self._betaincinv_lut:
            self._betaincinv_lut[key] = self._make_betaincinv_lut(af, bf, 20000)
        x_grid, y_grid = self._betaincinv_lut[key]
        x0 = self._inverse_lut(arr, x_grid, y_grid)
        # Newton: x1 = x0 - (I_x0 - y) / I'_x0 with
        # I'_x = x^(a-1) (1-x)^(b-1) / B(a, b), evaluated in log space.
        log_beta = math.lgamma(af) + math.lgamma(bf) - math.lgamma(af + bf)
        diff = self._scsp.betainc(af, bf, x0) - arr
        log_deriv = (af - 1.0) * np.log(np.clip(x0, 1e-300, None)) + \
            (bf - 1.0) * np.log(np.clip(1.0 - x0, 1e-300, None)) - log_beta
        x1 = x0 - diff / np.clip(np.exp(log_deriv), 1e-300, 1e300)
        return self._apply_endpoints(np.clip(x1, 1e-15, 1.0 - 1e-15), arr, 1.0)

    def gammainc(self, a, x):
        return self._scsp.gammainc(np.asarray(a, dtype=np.float64),
                                   np.asarray(x, dtype=np.float64))

    def gammaincc(self, a, x):
        return self._scsp.gammaincc(np.asarray(a, dtype=np.float64),
                                    np.asarray(x, dtype=np.float64))

    def gammaincinv(self, a, q):
        """Inverse of ``gammainc(a, .)``; LUT + one Newton step for scalar a."""
        arr = np.asarray(q, dtype=np.float64)
        try:
            af = float(a)
        except (TypeError, ValueError):
            return self._scsp.gammaincinv(a, arr)
        if not self.use_lut:
            return self._scsp.gammaincinv(af, arr)
        # scipy handles a < 1 (as before), non-finite a, and very large a.
        # For a > 1e5 the LUT range a + 20*sqrt(a) + 10 nears its 1e6 cap
        # (and the final 1e6 clip); the threshold is deliberately conservative.
        if af < 1.0 or af > 1e5 or not math.isfinite(af):
            return self._scsp.gammaincinv(af, arr)
        # LUT + 1-step Newton refinement
        key = (af,)
        if key not in self._gammaincinv_lut:
            self._gammaincinv_lut[key] = self._make_gammaincinv_lut(af, 20000)
        x_grid, y_grid = self._gammaincinv_lut[key]
        x0 = self._inverse_lut(arr, x_grid, y_grid)
        # 1 step of Newton: x = x0 - (P(a, x0) - q) / P'(a, x0),
        # P'(a, x) = x^(a-1) e^(-x) / Gamma(a) evaluated in log space.
        log_ga = math.lgamma(af)
        diff = self._scsp.gammainc(af, x0) - arr
        log_deriv = (af - 1.0) * np.log(np.clip(x0, 1e-300, None)) - x0 - log_ga
        x1 = x0 - diff / np.clip(np.exp(log_deriv), 1e-300, 1e300)
        return self._apply_endpoints(np.clip(x1, 1e-15, 1e6), arr, np.inf)

    def gammaln(self, x):
        return self._scsp.gammaln(np.asarray(x, dtype=np.float64))

    def erf(self, x):
        return self._scsp.erf(np.asarray(x, dtype=np.float64))

    def erfc(self, x):
        return self._scsp.erfc(np.asarray(x, dtype=np.float64))

    def erfcinv(self, y):
        return self._scsp.erfcinv(np.asarray(y, dtype=np.float64))

    def sqrt(self, x):
        return np.sqrt(np.asarray(x, dtype=np.float64))

    @property
    def pi(self):
        return np.pi

    def clip(self, x, lo, hi):
        return np.clip(x, lo, hi)

    def where(self, cond, x, y):
        return np.where(cond, x, y)

    def as_float64(self, x):
        return np.asarray(x, dtype=np.float64)

    # -- NumPy elementwise helpers used by the distribution classes ---------

    def exp(self, x):
        return np.exp(np.asarray(x, dtype=np.float64))

    def log(self, x):
        return np.log(np.asarray(x, dtype=np.float64))

    def log1p(self, x):
        return np.log1p(np.asarray(x, dtype=np.float64))

    def square(self, x):
        return np.square(np.asarray(x, dtype=np.float64))

    def abs(self, x):
        return np.abs(np.asarray(x, dtype=np.float64))

    def minimum(self, x, y):
        return np.minimum(x, y)

    def maximum(self, x, y):
        return np.maximum(x, y)

    def any(self, x):
        return np.any(x)

    def atan(self, x):
        return np.arctan(np.asarray(x, dtype=np.float64))

    def tan(self, x):
        return np.tan(np.asarray(x, dtype=np.float64))
|
|
824
|
+
|
|
825
|
+
|
|
826
|
+
# =============================================================================
|
|
827
|
+
# Distribution base classes — parameterized by SpecialFunctions
|
|
828
|
+
# =============================================================================
|
|
829
|
+
|
|
830
|
+
# Bracket used by the bisection fallback of the Student-t quantile
# (TDistributionBase._ppf_bisect); symmetric about 0.
_T_PPF_BISECT_UPPER = 64.0
_T_PPF_BISECT_LOWER = -_T_PPF_BISECT_UPPER
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
class NormDistributionBase:
    """scipy.stats.norm-like distribution, parameterized by SpecialFunctions.

    Methods accept scalars or arrays/tensors of the backing ``sf`` object.
    A non-positive ``scale`` yields NaN shaped like the input.

    Both tails go through ``erfc``/``erfcinv`` directly: cdf uses
    ``erfc(-x/sqrt(2))`` and isf uses ``erfcinv(2q)``. This way values such
    as ``cdf(-10)`` or ``isf(1e-20)`` neither cancel to 0 nor overflow to inf.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def _cdf_standard(self, x):
        # Phi(x) = erfc(-x / sqrt(2)) / 2. The algebraically equal form
        # (1 + erf(x / sqrt(2))) / 2 cancels to 0 for x below roughly -8.3.
        sf = self._sf
        return sf_safe_mul(sf.erfc(-x / sf.sqrt(2.0)), 0.5, sf)

    def _sf_standard(self, x):
        return sf_safe_mul(self._sf.erfc(x / self._sf.sqrt(2.0)), 0.5, self._sf)

    def _ppf_standard(self, q):
        return -self._sf.sqrt(2.0) * self._sf.erfcinv(2.0 * q)

    def _isf_standard(self, q):
        # isf(q) = ppf(1 - q) = sqrt(2) * erfcinv(2q), since erfcinv(2 - y) = -erfcinv(y).
        # Evaluating at q keeps tiny tail probabilities exact; 1 - q would
        # round to 1 (and the quantile to inf) for q below ~1e-16.
        return self._sf.sqrt(2.0) * self._sf.erfcinv(2.0 * q)

    def _two_sided_pvalue_standard(self, stat_abs):
        sf = self._sf
        return sf.clip(2.0 * self._sf_standard(sf.as_float64(stat_abs)), 0.0, 1.0)

    def _two_sided_critical_value_standard(self, alpha):
        sf = self._sf
        a = float(alpha)
        if not (0.0 < a < 1.0):
            return sf.as_float64(float("nan"))
        # z_{1 - alpha/2}, evaluated as isf(alpha/2) so tiny alpha stays finite.
        return self._isf_standard(a / 2.0)

    def cdf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(sf.as_float64(x) * 0 + 1, float("nan"), float("nan"))
        x_std = (sf.as_float64(x) - float(loc)) / scale_f
        return self._cdf_standard(x_std)

    def sf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(sf.as_float64(x) * 0 + 1, float("nan"), float("nan"))
        x_std = (sf.as_float64(x) - float(loc)) / scale_f
        return self._sf_standard(x_std)

    def ppf(self, q, *, loc=0.0, scale=1.0):
        sf = self._sf
        q_f = sf.as_float64(q)
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        return float(loc) + scale_f * self._ppf_standard(q_f)

    def isf(self, q, *, loc=0.0, scale=1.0):
        sf = self._sf
        q_f = sf.as_float64(q)
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        return float(loc) + scale_f * self._isf_standard(q_f)

    def pdf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        norm_const = sf.sqrt(2.0 * sf.pi)
        return sf.exp(-0.5 * sf.square(z)) / (scale_f * norm_const)

    def two_sided_pvalue(self, stat_abs):
        return self._two_sided_pvalue_standard(stat_abs)

    def two_sided_critical_value(self, alpha):
        return self._two_sided_critical_value_standard(alpha)

    def rvs(self, *, size=None, loc=0.0, scale=1.0, dtype=None):
        return _rvs_normal(self._sf, size=size, loc=loc, scale=scale)
|
|
914
|
+
|
|
915
|
+
|
|
916
|
+
class TDistributionBase:
    """scipy.stats.t-like distribution, parameterized by SpecialFunctions.

    ``df`` is converted with ``float()`` (scalar only). ``df <= 0`` or
    ``scale <= 0`` produce NaN shaped like the input.

    Upper-tail quantities use the symmetry of the t distribution about 0 —
    ``sf(x) = cdf(-x)`` and ``isf(q) = -ppf(q)`` (standardized) — rather than
    ``1 - cdf(x)`` / ``ppf(1 - q)``, so tiny tail probabilities keep full
    relative precision.
    """

    # Upper bound on bracket doublings in the bisection fallback;
    # 64 * 2**k overflows float64 to inf well before k reaches this value.
    _MAX_BRACKET_DOUBLINGS = 1100

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def _cdf_standard(self, x, df):
        # Lower tail P(T < -|x|) = 0.5 * I_{df/(df + x^2)}(df/2, 1/2).
        sf = self._sf
        df_val = float(df)
        if df_val <= 0:
            return sf.where(x * 0 + 1, float("nan"), float("nan"))
        z = df_val / (df_val + sf.square(sf.abs(x)))
        ibeta = sf.betainc(df_val / 2.0, 0.5, z)
        lower_tail = 0.5 * ibeta
        return sf.where(x >= 0.0, 1.0 - lower_tail, lower_tail)

    def _sf_standard(self, x, df):
        # Symmetry: P(T > x) = P(T < -x). Unlike 1 - cdf(x), this does not
        # round small upper-tail probabilities to 0.
        return self._cdf_standard(-x, df)

    def _two_sided_pvalue_standard(self, stat_abs, df):
        # P(|T| >= t) = I_{df/(df + t^2)}(df/2, 1/2).
        sf = self._sf
        df_val = float(df)
        if df_val <= 0:
            return sf.where(stat_abs * 0 + 1, float("nan"), float("nan"))
        z = df_val / (df_val + sf.square(sf.abs(stat_abs)))
        return sf.betainc(df_val / 2.0, 0.5, z)

    def _ppf_standard(self, q, df, *, max_bisect_steps=60):
        sf = self._sf
        df_val = float(df)
        if df_val <= 0:
            return sf.where(sf.as_float64(q) * 0 + 1, float("nan"), float("nan"))

        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        out = sf.where(q_f == 0.0, -float("inf"), out)
        out = sf.where(q_f == 1.0, float("inf"), out)

        valid = (q_f > 0.0) & (q_f < 1.0)
        if not bool(sf.any(valid)):
            return out

        try:
            # Invert the two-sided tail: y = 2 * min(q, 1 - q) = I_w(df/2, 1/2)
            # with w = df / (df + t^2), hence |t| = sqrt(df * (1 - w) / w).
            tail = sf.minimum(q_f, 1.0 - q_f)
            y = 2.0 * tail
            y_inv = sf.betaincinv(df_val / 2.0, 0.5, y)
            x_pos = sf.sqrt(df_val * (1.0 - y_inv) / y_inv)
            quant = sf.where(q_f >= 0.5, x_pos, -x_pos)
            return sf.where(valid, quant, out)
        except Exception:
            # Best-effort: backends whose betaincinv fails use bisection.
            return self._ppf_bisect(q_f, df_val, valid, out, max_bisect_steps)

    def _ppf_bisect(self, q, df_val, valid, out, steps):
        """Bisection fallback for the standard t quantile.

        Starts from ``[_T_PPF_BISECT_LOWER, _T_PPF_BISECT_UPPER]``. For valid
        entries only, either end is doubled until the bracket encloses the
        quantile (heavy tails put small-df quantiles far beyond +/-64). Then
        ``steps`` bisection iterations run.
        """
        sf = self._sf
        lo = sf.where(q * 0 + 1, _T_PPF_BISECT_LOWER, _T_PPF_BISECT_LOWER)
        hi = sf.where(q * 0 + 1, _T_PPF_BISECT_UPPER, _T_PPF_BISECT_UPPER)
        for _ in range(self._MAX_BRACKET_DOUBLINGS):
            grow_lo = valid & (self._cdf_standard(lo, df_val) > q)
            grow_hi = valid & (self._cdf_standard(hi, df_val) < q)
            if not (bool(sf.any(grow_lo)) or bool(sf.any(grow_hi))):
                break
            lo = sf.where(grow_lo, 2.0 * lo, lo)
            hi = sf.where(grow_hi, 2.0 * hi, hi)
        for _ in range(max(int(steps), 1)):
            mid = 0.5 * (lo + hi)
            cdf_mid = self._cdf_standard(mid, df_val)
            go_right = cdf_mid < q
            lo = sf.where(go_right, mid, lo)
            hi = sf.where(go_right, hi, mid)
        quant = 0.5 * (lo + hi)
        return sf.where(valid, quant, out)

    def cdf(self, x, df, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(sf.as_float64(x) * 0 + 1, float("nan"), float("nan"))
        x_std = (sf.as_float64(x) - float(loc)) / scale_f
        return self._cdf_standard(x_std, df)

    def sf(self, x, df, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(sf.as_float64(x) * 0 + 1, float("nan"), float("nan"))
        x_std = (sf.as_float64(x) - float(loc)) / scale_f
        return self._sf_standard(x_std, df)

    def ppf(self, q, df, *, loc=0.0, scale=1.0, max_bisect_steps=60):
        sf = self._sf
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(sf.as_float64(q) * 0 + 1, float("nan"), float("nan"))
        return float(loc) + scale_f * self._ppf_standard(q, df, max_bisect_steps=max_bisect_steps)

    def isf(self, q, df, *, loc=0.0, scale=1.0, max_bisect_steps=60):
        sf = self._sf
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(sf.as_float64(q) * 0 + 1, float("nan"), float("nan"))
        # Symmetry: ppf_std(1 - q) = -ppf_std(q). Evaluating at q keeps tiny
        # tail probabilities exact (1 - q would round to 1 -> inf).
        return float(loc) - scale_f * self._ppf_standard(q, df, max_bisect_steps=max_bisect_steps)

    def pdf(self, x, df, *, loc=0.0, scale=1.0):
        sf = self._sf
        x_f = sf.as_float64(x)
        df_val = float(df)
        scale_f = float(scale)
        if df_val <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        half_nu = df_val / 2.0
        log_coef = (
            sf.gammaln((df_val + 1.0) / 2.0)
            - sf.gammaln(half_nu)
            - 0.5 * (sf.log(df_val) + sf.log(sf.pi))
        )
        log_pdf = (
            log_coef
            - ((df_val + 1.0) / 2.0) * sf.log1p(sf.square(z) / df_val)
            - sf.log(scale_f)
        )
        return sf.exp(log_pdf)

    def two_sided_pvalue(self, stat_abs, df):
        return self._two_sided_pvalue_standard(stat_abs, df)

    def two_sided_critical_value(self, alpha, df, *, max_bisect_steps=60):
        sf = self._sf
        a = float(alpha)
        if not (0.0 < a < 1.0):
            return sf.as_float64(float("nan"))
        # t_{1 - alpha/2} = -ppf(alpha/2) by symmetry; stays finite for tiny alpha.
        return -self._ppf_standard(a / 2.0, df, max_bisect_steps=max_bisect_steps)

    def rvs(self, df, *, size=None, loc=0.0, scale=1.0, dtype=None):
        return _rvs_t(self._sf, df=df, size=size, loc=loc, scale=scale)
|
|
1044
|
+
|
|
1045
|
+
|
|
1046
|
+
class UniformDistributionBase:
    """scipy.stats.uniform-like distribution on ``[loc, loc + scale]``.

    Backed by a ``SpecialFunctions`` object. A non-positive ``scale`` makes
    every method except ``rvs`` return NaN shaped like the input.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def _all_nan(self, values):
        # NaN broadcast to the shape of ``values``.
        return self._sf.where(values * 0 + 1, float("nan"), float("nan"))

    def cdf(self, x, *, loc=0.0, scale=1.0):
        s = self._sf
        width = float(scale)
        values = s.as_float64(x)
        if width <= 0.0:
            return self._all_nan(values)
        return s.clip((values - float(loc)) / width, 0.0, 1.0)

    def sf(self, x, *, loc=0.0, scale=1.0):
        below = self.cdf(x, loc=loc, scale=scale)
        return sf_safe_sub(1.0, below, self._sf)

    def ppf(self, q, *, loc=0.0, scale=1.0):
        s = self._sf
        width = float(scale)
        probs = s.as_float64(q)
        nan_out = self._all_nan(probs)
        if width <= 0.0:
            return nan_out
        inside = (probs >= 0.0) & (probs <= 1.0)
        return s.where(inside, float(loc) + width * probs, nan_out)

    def isf(self, q, *, loc=0.0, scale=1.0):
        complement = 1.0 - self._sf.as_float64(q)
        return self.ppf(complement, loc=loc, scale=scale)

    def pdf(self, x, *, loc=0.0, scale=1.0):
        s = self._sf
        width = float(scale)
        values = s.as_float64(x)
        if width <= 0.0:
            return self._all_nan(values)
        u = (values - float(loc)) / width
        return s.where((u >= 0.0) & (u <= 1.0), 1.0 / width, 0.0)

    def rvs(self, *, size=None, loc=0.0, scale=1.0, dtype=None):
        return _rvs_uniform(self._sf, size=size, loc=loc, scale=scale)
|
|
1087
|
+
|
|
1088
|
+
|
|
1089
|
+
class ExponDistributionBase:
    """scipy.stats.expon-like distribution with support ``[loc, inf)``.

    A non-positive ``scale`` yields NaN shaped like the input.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def cdf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        x_f = sf.as_float64(x)
        scale_f = float(scale)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return sf.where(z <= 0.0, 0.0, 1.0 - sf.exp(-z))

    def sf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        x_f = sf.as_float64(x)
        scale_f = float(scale)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return sf.where(z <= 0.0, 1.0, sf.exp(-z))

    def ppf(self, q, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float(loc), out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) - scale_f * sf.log1p(-q_f), out)

    def isf(self, q, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if scale_f <= 0.0:
            return out
        out = sf.where(q_f == 1.0, float(loc), out)
        out = sf.where(q_f == 0.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        # isf(q) = loc - scale * log(q): exact for tiny q, where the former
        # ppf(1 - q) rounded 1 - q to 1 and returned inf.
        return sf.where(valid, float(loc) - scale_f * sf.log(q_f), out)

    def pdf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        x_f = sf.as_float64(x)
        scale_f = float(scale)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return sf.where(z >= 0.0, sf.exp(-z) / scale_f, 0.0)

    def rvs(self, *, size=None, loc=0.0, scale=1.0, dtype=None):
        return _rvs_expon(self._sf, size=size, loc=loc, scale=scale)
|
|
1137
|
+
|
|
1138
|
+
|
|
1139
|
+
class CauchyDistributionBase:
    """scipy.stats.cauchy-like distribution, parameterized by SpecialFunctions.

    A non-positive ``scale`` yields NaN shaped like the input.

    The lower tail of the cdf and both tails of the quantile function use
    reciprocal-angle forms (``atan(-1/z)``, ``1/tan(pi*q)``) so that far-tail
    values keep full relative precision. sf and isf use the symmetry about
    ``loc`` instead of ``1 - cdf`` / ``ppf(1 - q)``.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def _cdf_standard(self, z):
        """Standard Cauchy CDF, accurate in both tails."""
        sf = self._sf
        neg = z < 0.0
        # For z < 0: 1/2 + atan(z)/pi == atan(-1/z)/pi, which does not cancel
        # as z -> -inf. The -1 placeholder keeps 1/z finite off that branch.
        z_neg = sf.where(neg, z, -1.0)
        return sf.where(neg, sf.atan(-1.0 / z_neg) / sf.pi, 0.5 + sf.atan(z) / sf.pi)

    def _ppf_standard(self, q_f):
        """Standard Cauchy quantile for 0 < q < 1; other entries are masked by callers."""
        sf = self._sf
        lower = (q_f > 0.0) & (q_f < 0.25)
        upper = (q_f > 0.75) & (q_f < 1.0)
        # Centre: tan(pi*(q - 1/2)); here q - 1/2 is exact and |angle| <= pi/4.
        centre = sf.tan(sf.pi * (q_f - 0.5))
        # Tails: tan(pi*(q - 1/2)) = -1/tan(pi*q) = 1/tan(pi*(1 - q)). The small
        # angles pi*q and pi*(1 - q) carry full relative precision, unlike
        # pi*(q - 1/2) near +-pi/2. Placeholders avoid 1/0 off-branch.
        q_lo = sf.where(lower, q_f, 0.25)
        q_hi = sf.where(upper, 1.0 - q_f, 0.25)
        out = sf.where(lower, -1.0 / sf.tan(sf.pi * q_lo), centre)
        return sf.where(upper, 1.0 / sf.tan(sf.pi * q_hi), out)

    def _quantile(self, q, loc, scale, *, upper_tail):
        """Shared ppf / isf: isf(q) = loc - scale * ppf_std(q) by symmetry."""
        sf = self._sf
        scale_f = float(scale)
        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if scale_f <= 0.0:
            return out
        sign = -1.0 if upper_tail else 1.0
        out = sf.where(q_f == 0.0, -sign * float("inf"), out)
        out = sf.where(q_f == 1.0, sign * float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + sign * scale_f * self._ppf_standard(q_f), out)

    def cdf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return self._cdf_standard(z)

    def sf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        # Symmetry: sf(z) = cdf(-z); avoids cancellation in 1 - cdf(z).
        z = (x_f - float(loc)) / scale_f
        return self._cdf_standard(-z)

    def ppf(self, q, *, loc=0.0, scale=1.0):
        return self._quantile(q, loc, scale, upper_tail=False)

    def isf(self, q, *, loc=0.0, scale=1.0):
        return self._quantile(q, loc, scale, upper_tail=True)

    def pdf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return 1.0 / (sf.pi * scale_f * (1.0 + sf.square(z)))

    def rvs(self, *, size=None, loc=0.0, scale=1.0, dtype=None):
        return _rvs_cauchy(self._sf, size=size, loc=loc, scale=scale)
|
|
1181
|
+
|
|
1182
|
+
|
|
1183
|
+
class LaplaceDistributionBase:
    """scipy.stats.laplace-like distribution, parameterized by SpecialFunctions.

    A non-positive ``scale`` yields NaN shaped like the input. isf uses the
    symmetry about ``loc`` (``isf(q) = loc - scale * ppf_std(q)``) instead of
    ``ppf(1 - q)``, so tiny upper-tail probabilities stay exact.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def cdf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return sf.where(z < 0.0, 0.5 * sf.exp(z), 1.0 - 0.5 * sf.exp(-z))

    def sf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return sf.where(z < 0.0, 1.0 - 0.5 * sf.exp(z), 0.5 * sf.exp(-z))

    def _ppf_standard(self, q_f):
        """Standard Laplace quantile for 0 < q < 1; other entries are masked by callers."""
        sf = self._sf
        # Lower half: log(2q). Upper half: -log(2(1 - q)); 1 - q is exact for q >= 1/2.
        return sf.where(q_f < 0.5, sf.log(2.0 * q_f), -sf.log(2.0 * (1.0 - q_f)))

    def _quantile(self, q, loc, scale, *, upper_tail):
        """Shared ppf / isf: isf(q) = loc - scale * ppf_std(q) by symmetry."""
        sf = self._sf
        scale_f = float(scale)
        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if scale_f <= 0.0:
            return out
        sign = -1.0 if upper_tail else 1.0
        out = sf.where(q_f == 0.0, -sign * float("inf"), out)
        out = sf.where(q_f == 1.0, sign * float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + sign * scale_f * self._ppf_standard(q_f), out)

    def ppf(self, q, *, loc=0.0, scale=1.0):
        return self._quantile(q, loc, scale, upper_tail=False)

    def isf(self, q, *, loc=0.0, scale=1.0):
        return self._quantile(q, loc, scale, upper_tail=True)

    def pdf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = sf.abs((x_f - float(loc)) / scale_f)
        return 0.5 * sf.exp(-z) / scale_f

    def rvs(self, *, size=None, loc=0.0, scale=1.0, dtype=None):
        return _rvs_laplace(self._sf, size=size, loc=loc, scale=scale)
|
|
1234
|
+
|
|
1235
|
+
|
|
1236
|
+
class LogisticDistributionBase:
    """scipy.stats.logistic-like distribution, parameterized by SpecialFunctions.

    A non-positive ``scale`` yields NaN shaped like the input.

    Upper-tail quantities avoid ``1 - cdf``:

    * isf evaluates ``log((1 - q) / q)`` directly;
    * pdf multiplies cdf by sf (the exact complement) rather than by
      ``1 - cdf``.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def cdf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return 1.0 / (1.0 + sf.exp(-z))

    def sf(self, x, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return 1.0 / (1.0 + sf.exp(z))

    def ppf(self, q, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, -float("inf"), out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + scale_f * sf.log(q_f / (1.0 - q_f)), out)

    def isf(self, q, *, loc=0.0, scale=1.0):
        sf = self._sf
        scale_f = float(scale)
        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float("inf"), out)
        out = sf.where(q_f == 1.0, -float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        # isf(q) = loc + scale * log((1 - q) / q); evaluating at q (not
        # ppf(1 - q)) keeps tiny tail probabilities exact.
        return sf.where(valid, float(loc) + scale_f * sf.log((1.0 - q_f) / q_f), out)

    def pdf(self, x, *, loc=0.0, scale=1.0):
        cdf_x = self.cdf(x, loc=loc, scale=scale)
        scale_f = float(scale)
        if scale_f <= 0.0:
            return cdf_x
        # f = F * (1 - F) / scale, with 1 - F taken from sf(): computing it as
        # 1 - cdf cancels to 0 in the upper tail.
        return cdf_x * self.sf(x, loc=loc, scale=scale) / scale_f

    def rvs(self, *, size=None, loc=0.0, scale=1.0, dtype=None):
        return _rvs_logistic(self._sf, size=size, loc=loc, scale=scale)
|
|
1283
|
+
|
|
1284
|
+
|
|
1285
|
+
class Chi2DistributionBase:
    """scipy.stats.chi2-like distribution, parameterized by SpecialFunctions.

    ``df`` is converted with ``float()``; ``df <= 0`` yields NaN shaped like
    the input.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def cdf(self, x, df):
        # P(X <= x) = P(df/2, x/2) (regularized lower incomplete gamma).
        sf = self._sf
        x_f = sf.as_float64(x)
        df_f = float(df)
        if df_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = x_f / 2.0
        return sf.where(x_f <= 0.0, 0.0, sf.gammainc(df_f / 2.0, y))

    def sf(self, x, df):
        sf = self._sf
        x_f = sf.as_float64(x)
        df_f = float(df)
        if df_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = x_f / 2.0
        return sf.where(x_f <= 0.0, 1.0, sf.gammaincc(df_f / 2.0, y))

    def ppf(self, q, df):
        sf = self._sf
        q_f = sf.as_float64(q)
        df_f = float(df)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if df_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, 0.0, out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, 2.0 * sf.gammaincinv(df_f / 2.0, q_f), out)

    def isf(self, q, df):
        # NOTE(review): evaluates ppf(1 - q), so q below ~1e-16 rounds to
        # ppf(1) = inf. An upper-tail inverse (gammainccinv) would avoid this
        # but is not part of the SpecialFunctions methods used here.
        return self.ppf(1.0 - self._sf.as_float64(q), df)

    def pdf(self, x, df):
        sf = self._sf
        x_f = sf.as_float64(x)
        df_f = float(df)
        if df_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = sf.maximum(x_f, 1e-300)
        logpdf = ((df_f / 2.0) - 1.0) * sf.log(y) - y / 2.0 - (df_f / 2.0) * sf.log(2.0) - sf.gammaln(df_f / 2.0)
        out = sf.where(x_f > 0.0, sf.exp(logpdf), 0.0)
        # Density at x == 0 (matches scipy.stats.chi2): unbounded for df < 2,
        # exactly 1/2 for df == 2, and 0 for df > 2.
        if df_f < 2.0:
            at_zero = float("inf")
        elif df_f == 2.0:
            at_zero = 0.5
        else:
            at_zero = 0.0
        return sf.where(x_f == 0.0, at_zero, out)

    def rvs(self, df, *, size=None, dtype=None):
        return _rvs_chi2(self._sf, df=df, size=size)
|
|
1334
|
+
|
|
1335
|
+
|
|
1336
|
+
class GammaDistributionBase:
    """Gamma distribution with shape ``a`` and optional ``loc``/``scale``.

    Works on the standardized variable ``y = (x - loc) / scale``.  A
    non-positive ``a`` or ``scale`` yields an all-NaN result shaped like
    the input.
    """

    def __init__(self, sf: "SpecialFunctions"):
        self._sf = sf

    def cdf(self, x, a, *, loc=0.0, scale=1.0):
        """P(X <= x) = gammainc(a, y); 0 for y <= 0."""
        sf = self._sf
        x_f = sf.as_float64(x)
        a_f = float(a)
        scale_f = float(scale)
        if a_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        return sf.where(y <= 0.0, 0.0, sf.gammainc(a_f, y))

    def sf(self, x, a, *, loc=0.0, scale=1.0):
        """P(X > x) = gammaincc(a, y); 1 for y <= 0."""
        sf = self._sf
        x_f = sf.as_float64(x)
        a_f = float(a)
        scale_f = float(scale)
        if a_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        return sf.where(y <= 0.0, 1.0, sf.gammaincc(a_f, y))

    def ppf(self, q, a, *, loc=0.0, scale=1.0):
        """Quantile ``loc + scale * gammaincinv(a, q)``.

        Returns ``loc`` at q == 0, +inf at q == 1, and NaN outside [0, 1].
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        a_f = float(a)
        scale_f = float(scale)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if a_f <= 0.0 or scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float(loc), out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + scale_f * sf.gammaincinv(a_f, q_f), out)

    def isf(self, q, a, *, loc=0.0, scale=1.0):
        """Inverse survival function, evaluated as ``ppf(1 - q)``."""
        return self.ppf(1.0 - self._sf.as_float64(q), a, loc=loc, scale=scale)

    def pdf(self, x, a, *, loc=0.0, scale=1.0):
        """Gamma density.

        Returns 0 for y < 0, the exact limiting value at y == 0, and NaN
        for NaN input.
        """
        sf = self._sf
        x_f = sf.as_float64(x)
        a_f = float(a)
        scale_f = float(scale)
        if a_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        y_safe = sf.maximum(y, 1e-300)
        logpdf = (a_f - 1.0) * sf.log(y_safe) - y_safe - sf.gammaln(a_f) - sf.log(scale_f)
        # Density limit as y -> 0+: infinite for a < 1, 1/scale for a == 1
        # (exponential), 0 for a > 1.
        if a_f < 1.0:
            at_zero = float("inf")
        elif a_f == 1.0:
            at_zero = 1.0 / scale_f
        else:
            at_zero = 0.0
        dens = sf.where(y == 0.0, at_zero, sf.exp(logpdf))
        # Only negative y is outside the support; NaN input stays NaN.
        return sf.where(y < 0.0, 0.0, dens)

    def rvs(self, a, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw samples via ``_rvs_gamma``; ``dtype`` is accepted but not forwarded."""
        return _rvs_gamma(self._sf, a=a, size=size, loc=loc, scale=scale)
|
|
1390
|
+
|
|
1391
|
+
|
|
1392
|
+
class BetaDistributionBase:
    """Beta distribution with shapes ``a``, ``b`` and optional ``loc``/``scale``.

    Works on the standardized variable ``y = (x - loc) / scale`` with support
    [0, 1].  A non-positive ``a``, ``b`` or ``scale`` yields an all-NaN
    result shaped like the input.
    """

    def __init__(self, sf: "SpecialFunctions"):
        self._sf = sf

    def cdf(self, x, a, b, *, loc=0.0, scale=1.0):
        """P(X <= x) = I_y(a, b); 0 below the support and 1 above it."""
        sf = self._sf
        x_f = sf.as_float64(x)
        a_f = float(a)
        b_f = float(b)
        scale_f = float(scale)
        if a_f <= 0.0 or b_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        core = sf.betainc(a_f, b_f, sf.clip(y, 0.0, 1.0))
        out = sf.where(y <= 0.0, 0.0, core)
        return sf.where(y >= 1.0, 1.0, out)

    def sf(self, x, a, b, *, loc=0.0, scale=1.0):
        """P(X > x) = I_{1-y}(b, a); 1 below the support and 0 above it.

        Uses the reflection I_y(a, b) = 1 - I_{1-y}(b, a) to evaluate the upper
        tail directly.  ``1 - cdf`` cancels once the CDF rounds towards 1.
        """
        sf = self._sf
        x_f = sf.as_float64(x)
        a_f = float(a)
        b_f = float(b)
        scale_f = float(scale)
        if a_f <= 0.0 or b_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        core = sf.betainc(b_f, a_f, sf.clip(1.0 - y, 0.0, 1.0))
        out = sf.where(y <= 0.0, 1.0, core)
        return sf.where(y >= 1.0, 0.0, out)

    def ppf(self, q, a, b, *, loc=0.0, scale=1.0):
        """Quantile ``loc + scale * betaincinv(a, b, q)``.

        Returns ``loc`` at q == 0, ``loc + scale`` at q == 1, and NaN outside
        [0, 1].
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        a_f = float(a)
        b_f = float(b)
        scale_f = float(scale)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if a_f <= 0.0 or b_f <= 0.0 or scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float(loc), out)
        out = sf.where(q_f == 1.0, float(loc) + scale_f, out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + scale_f * sf.betaincinv(a_f, b_f, q_f), out)

    def isf(self, q, a, b, *, loc=0.0, scale=1.0):
        """Inverse survival function ``loc + scale * (1 - betaincinv(b, a, q))``.

        Inverting the reflected tail avoids forming ``1 - q``, which rounds
        tiny upper-tail probabilities to 1.  Returns ``loc + scale`` at
        q == 0, ``loc`` at q == 1, and NaN outside [0, 1].
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        a_f = float(a)
        b_f = float(b)
        scale_f = float(scale)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if a_f <= 0.0 or b_f <= 0.0 or scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float(loc) + scale_f, out)
        out = sf.where(q_f == 1.0, float(loc), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        upper = 1.0 - sf.betaincinv(b_f, a_f, q_f)
        return sf.where(valid, float(loc) + scale_f * upper, out)

    def pdf(self, x, a, b, *, loc=0.0, scale=1.0):
        """Beta density.

        Returns 0 outside [0, 1], the exact limiting values at y == 0 and
        y == 1, and NaN for NaN input.
        """
        sf = self._sf
        x_f = sf.as_float64(x)
        a_f = float(a)
        b_f = float(b)
        scale_f = float(scale)
        if a_f <= 0.0 or b_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        y_safe = sf.clip(y, 1e-300, 1.0 - 1e-300)
        betaln = sf.gammaln(a_f) + sf.gammaln(b_f) - sf.gammaln(a_f + b_f)
        logpdf = (a_f - 1.0) * sf.log(y_safe) + (b_f - 1.0) * sf.log1p(-y_safe) - betaln - sf.log(scale_f)
        # Limits at the support edges.  At y -> 0+ the density is infinite for
        # a < 1, b / scale for a == 1 (since B(1, b) = 1/b), and 0 for a > 1.
        # The y -> 1- edge is the same with a and b swapped.
        if a_f < 1.0:
            left = float("inf")
        elif a_f == 1.0:
            left = b_f / scale_f
        else:
            left = 0.0
        if b_f < 1.0:
            right = float("inf")
        elif b_f == 1.0:
            right = a_f / scale_f
        else:
            right = 0.0
        dens = sf.where(y == 0.0, left, sf.exp(logpdf))
        dens = sf.where(y == 1.0, right, dens)
        # Only points strictly outside [0, 1] get 0; NaN input stays NaN.
        outside = (y < 0.0) | (y > 1.0)
        return sf.where(outside, 0.0, dens)

    def rvs(self, a, b, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw samples via ``_rvs_beta``; ``dtype`` is accepted but not forwarded."""
        return _rvs_beta(self._sf, a=a, b=b, size=size, loc=loc, scale=scale)
|
|
1446
|
+
|
|
1447
|
+
|
|
1448
|
+
class FDistributionBase:
    """Fisher F distribution with ``dfn`` / ``dfd`` degrees of freedom.

    CDF, SF and quantiles are expressed through the regularized incomplete
    beta function.  A non-positive ``dfn`` or ``dfd`` yields an all-NaN
    result shaped like the input.
    """

    def __init__(self, sf: "SpecialFunctions"):
        self._sf = sf

    def cdf(self, x, dfn, dfd):
        """P(X <= x) = I_z(dfn/2, dfd/2) with z = dfn*x / (dfn*x + dfd); 0 for x <= 0."""
        sf = self._sf
        x_f = sf.as_float64(x)
        dfn_f = float(dfn)
        dfd_f = float(dfd)
        if dfn_f <= 0.0 or dfd_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        nx = dfn_f * sf.maximum(x_f, 0.0)
        z = nx / (nx + dfd_f)
        # inf / inf is NaN; the limit of z (and of the CDF) at x = +inf is 1.
        z = sf.where(x_f == float("inf"), 1.0, z)
        core = sf.betainc(dfn_f / 2.0, dfd_f / 2.0, z)
        return sf.where(x_f <= 0.0, 0.0, core)

    def sf(self, x, dfn, dfd):
        """P(X > x) = I_w(dfd/2, dfn/2) with w = dfd / (dfd + dfn*x); 1 for x <= 0.

        Evaluated directly rather than as ``1 - cdf`` so that the small upper-tail
        probabilities used as p-values do not cancel to 0.
        """
        sf = self._sf
        x_f = sf.as_float64(x)
        dfn_f = float(dfn)
        dfd_f = float(dfd)
        if dfn_f <= 0.0 or dfd_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        w = dfd_f / (dfd_f + dfn_f * sf.maximum(x_f, 0.0))
        return sf.where(x_f <= 0.0, 1.0, sf.betainc(dfd_f / 2.0, dfn_f / 2.0, w))

    def ppf(self, q, dfn, dfd):
        """Quantile from ``z = betaincinv(dfn/2, dfd/2, q)`` via x = dfd*z / (dfn*(1 - z)).

        Returns 0 at q == 0, +inf at q == 1, and NaN outside [0, 1].
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        dfn_f = float(dfn)
        dfd_f = float(dfd)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if dfn_f <= 0.0 or dfd_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, 0.0, out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        z = sf.betaincinv(dfn_f / 2.0, dfd_f / 2.0, q_f)
        return sf.where(valid, (dfd_f * z) / (dfn_f * (1.0 - z)), out)

    def isf(self, q, dfn, dfd):
        """Inverse survival function, inverting the upper tail directly.

        Solves I_w(dfd/2, dfn/2) = q for w and returns x = dfd*(1 - w) / (dfn*w).
        The previous ``ppf(1 - q)`` form rounded tiny q (extreme critical
        values) to 1 and returned +inf.  Returns +inf at q == 0, 0 at q == 1,
        and NaN outside [0, 1].
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        dfn_f = float(dfn)
        dfd_f = float(dfd)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if dfn_f <= 0.0 or dfd_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float("inf"), out)
        out = sf.where(q_f == 1.0, 0.0, out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        w = sf.betaincinv(dfd_f / 2.0, dfn_f / 2.0, q_f)
        return sf.where(valid, (dfd_f * (1.0 - w)) / (dfn_f * w), out)

    def pdf(self, x, dfn, dfd):
        """F density.

        Returns 0 for x < 0, the exact limiting value at x == 0, and NaN
        for NaN input.
        """
        sf = self._sf
        x_f = sf.as_float64(x)
        dfn_f = float(dfn)
        dfd_f = float(dfd)
        if dfn_f <= 0.0 or dfd_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        a = dfn_f / 2.0
        b = dfd_f / 2.0
        x_safe = sf.maximum(x_f, 1e-300)
        betaln = sf.gammaln(a) + sf.gammaln(b) - sf.gammaln(a + b)
        logpdf = a * sf.log(dfn_f / dfd_f) + (a - 1.0) * sf.log(x_safe) - betaln - (a + b) * sf.log1p((dfn_f / dfd_f) * x_safe)
        # Density limit as x -> 0+: infinite for dfn < 2, exactly 1 for dfn == 2
        # (then pdf(x) = (1 + 2x/dfd)**-(1 + dfd/2)), and 0 for dfn > 2.
        if a < 1.0:
            at_zero = float("inf")
        elif a == 1.0:
            at_zero = 1.0
        else:
            at_zero = 0.0
        dens = sf.where(x_f == 0.0, at_zero, sf.exp(logpdf))
        # Only negative x is outside the support; NaN input stays NaN.
        return sf.where(x_f < 0.0, 0.0, dens)

    def rvs(self, dfn, dfd, *, size=None, dtype=None):
        """Draw samples via ``_rvs_f``; ``dtype`` is accepted but not forwarded."""
        return _rvs_f(self._sf, dfn=dfn, dfd=dfd, size=size)
|
|
1499
|
+
|
|
1500
|
+
|
|
1501
|
+
class WeibullMinDistributionBase:
    """Weibull (minimum) distribution with shape ``c`` and optional ``loc``/``scale``.

    Works on the standardized variable ``y = (x - loc) / scale``; the CDF is
    ``1 - exp(-y**c)``.  A non-positive ``c`` or ``scale`` yields an all-NaN
    result shaped like the input.
    """

    def __init__(self, sf: "SpecialFunctions"):
        self._sf = sf

    def cdf(self, x, c, *, loc=0.0, scale=1.0):
        """P(X <= x) = 1 - exp(-y**c); 0 for y <= 0."""
        sf = self._sf
        x_f = sf.as_float64(x)
        c_f = float(c)
        scale_f = float(scale)
        if c_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        yc = sf.power(sf.maximum(y, 0.0), c_f)
        # 1 - exp(-t) cancels for small t.  Below 1e-5 use the Taylor
        # polynomial t - t**2/2 + t**3/6, whose relative truncation error is
        # under 5e-17 there.
        series = yc * (1.0 - yc * (0.5 - yc / 6.0))
        tail = sf.where(yc < 1e-5, series, 1.0 - sf.exp(-yc))
        return sf.where(y <= 0.0, 0.0, tail)

    def sf(self, x, c, *, loc=0.0, scale=1.0):
        """P(X > x) = exp(-y**c); 1 for y <= 0."""
        sf = self._sf
        x_f = sf.as_float64(x)
        c_f = float(c)
        scale_f = float(scale)
        if c_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        yc = sf.power(sf.maximum(y, 0.0), c_f)
        return sf.where(y <= 0.0, 1.0, sf.exp(-yc))

    def ppf(self, q, c, *, loc=0.0, scale=1.0):
        """Quantile ``loc + scale * (-log1p(-q))**(1/c)``.

        Returns ``loc`` at q == 0, +inf at q == 1, and NaN outside [0, 1].
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        c_f = float(c)
        scale_f = float(scale)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if c_f <= 0.0 or scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float(loc), out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + scale_f * sf.power(-sf.log1p(-q_f), 1.0 / c_f), out)

    def isf(self, q, c, *, loc=0.0, scale=1.0):
        """Inverse survival function ``loc + scale * (-log(q))**(1/c)``.

        Inverting sf(x) = exp(-y**c) directly avoids forming ``1 - q``, which
        rounds tiny q to 1 (the old ``ppf(1 - q)`` path then returned +inf).
        Returns +inf at q == 0, ``loc`` at q == 1, and NaN outside [0, 1].
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        c_f = float(c)
        scale_f = float(scale)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if c_f <= 0.0 or scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float("inf"), out)
        out = sf.where(q_f == 1.0, float(loc), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + scale_f * sf.power(-sf.log(q_f), 1.0 / c_f), out)

    def pdf(self, x, c, *, loc=0.0, scale=1.0):
        """Weibull density.

        Returns 0 for y < 0, the exact limiting value at y == 0, and NaN
        for NaN input.
        """
        sf = self._sf
        x_f = sf.as_float64(x)
        c_f = float(c)
        scale_f = float(scale)
        if c_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        y_pos = sf.maximum(y, 1e-300)
        logpdf = sf.log(c_f / scale_f) + (c_f - 1.0) * sf.log(y_pos) - sf.power(y_pos, c_f)
        # Density limit as y -> 0+: infinite for c < 1, 1/scale for c == 1
        # (exponential), 0 for c > 1.
        if c_f < 1.0:
            at_zero = float("inf")
        elif c_f == 1.0:
            at_zero = 1.0 / scale_f
        else:
            at_zero = 0.0
        dens = sf.where(y == 0.0, at_zero, sf.exp(logpdf))
        # Only negative y is outside the support; NaN input stays NaN.
        return sf.where(y < 0.0, 0.0, dens)

    def rvs(self, c, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw samples via ``_rvs_weibull``; ``dtype`` is accepted but not forwarded."""
        return _rvs_weibull(self._sf, c=c, size=size, loc=loc, scale=scale)
|
|
1557
|
+
|
|
1558
|
+
|
|
1559
|
+
class LognormDistributionBase:
    """Log-normal distribution with shape ``s`` and optional ``loc``/``scale``.

    Works on ``y = (x - loc) / scale`` and delegates standard-normal
    CDF/SF/PPF evaluations to ``norm_dist`` (its ``_cdf_standard``,
    ``_sf_standard`` and ``_ppf_standard`` methods).  A non-positive ``s``
    or ``scale`` yields an all-NaN result shaped like the input.
    """

    def __init__(self, sf: "SpecialFunctions", norm_dist: "NormDistributionBase"):
        self._sf = sf
        self._norm = norm_dist

    def cdf(self, x, s, *, loc=0.0, scale=1.0):
        """P(X <= x) = Phi(log(y) / s); 0 for y <= 0."""
        sf = self._sf
        x_f = sf.as_float64(x)
        s_f = float(s)
        scale_f = float(scale)
        if s_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        z = sf.log(sf.maximum(y, 1e-300)) / s_f
        return sf.where(y <= 0.0, 0.0, self._norm._cdf_standard(z))

    def sf(self, x, s, *, loc=0.0, scale=1.0):
        """P(X > x) = 1 - Phi(log(y) / s) via the normal SF; 1 for y <= 0."""
        sf = self._sf
        x_f = sf.as_float64(x)
        s_f = float(s)
        scale_f = float(scale)
        if s_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        z = sf.log(sf.maximum(y, 1e-300)) / s_f
        return sf.where(y <= 0.0, 1.0, self._norm._sf_standard(z))

    def ppf(self, q, s, *, loc=0.0, scale=1.0):
        """Quantile ``loc + scale * exp(s * Phi^{-1}(q))``.

        Returns ``loc`` at q == 0, +inf at q == 1, and NaN outside [0, 1].
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        s_f = float(s)
        scale_f = float(scale)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if s_f <= 0.0 or scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float(loc), out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + scale_f * sf.exp(s_f * self._norm._ppf_standard(q_f)), out)

    def isf(self, q, s, *, loc=0.0, scale=1.0):
        """Inverse survival function ``loc + scale * exp(-s * Phi^{-1}(q))``.

        Uses the normal symmetry Phi^{-1}(1 - q) = -Phi^{-1}(q) instead of
        ``ppf(1 - q)``.  Forming ``1 - q`` rounds tiny q to 1 and produced
        +inf.  Returns +inf at q == 0, ``loc`` at q == 1, and NaN outside
        [0, 1].
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        s_f = float(s)
        scale_f = float(scale)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if s_f <= 0.0 or scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float("inf"), out)
        out = sf.where(q_f == 1.0, float(loc), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + scale_f * sf.exp(-s_f * self._norm._ppf_standard(q_f)), out)

    def pdf(self, x, s, *, loc=0.0, scale=1.0):
        """Log-normal density.

        Returns 0 for y <= 0 and NaN for NaN input.
        """
        sf = self._sf
        x_f = sf.as_float64(x)
        s_f = float(s)
        scale_f = float(scale)
        if s_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        y_pos = sf.maximum(y, 1e-300)
        z = sf.log(y_pos) / s_f
        # The normalizing constant uses math.pi, so the backend object does not
        # need to expose a ``pi`` attribute.
        logpdf = -0.5 * sf.square(z) - sf.log(y_pos * s_f) - 0.5 * math.log(2.0 * math.pi) - sf.log(scale_f)
        # ``y <= 0`` is False for NaN, so NaN input propagates instead of becoming 0.
        return sf.where(y <= 0.0, 0.0, sf.exp(logpdf))

    def rvs(self, s, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw samples via ``_rvs_lognorm``; ``dtype`` is accepted but not forwarded."""
        return _rvs_lognorm(self._sf, s=s, size=size, loc=loc, scale=scale)
|
|
1617
|
+
|
|
1618
|
+
|
|
1619
|
+
class PoissonDistributionBase:
    """Poisson distribution with mean ``mu``, shifted by ``loc``.

    The CDF uses the identity P(X <= k) = gammaincc(k + 1, mu); quantiles
    come from a vectorized bisection on that CDF.  A negative ``mu`` yields
    an all-NaN result shaped like the input.
    """

    def __init__(self, sf: "SpecialFunctions"):
        self._sf = sf

    def _ppf_search(self, q, mu):
        """Return the smallest integer k with CDF(k) >= q (before ``loc`` shifting).

        q == 0 maps to -1 and q == 1 to +inf; q outside [0, 1] gives NaN.
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        mu_f = float(mu)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if mu_f < 0.0:
            return out
        out = sf.where(q_f == 0.0, -1.0, out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        if not bool(sf.any(valid)):
            return out
        # Initial upper bracket: roughly mu + 10*sqrt(mu + 1) + 10.
        hi0 = float(max(1.0, np.ceil(mu_f + 10.0 * np.sqrt(mu_f + 1.0) + 10.0)))
        low = sf.where(q_f * 0 + 1, -1.0, -1.0)
        high = sf.where(q_f * 0 + 1, hi0, hi0)
        # Grow the upper bracket (at most 16 doublings) until CDF(high) >= q.
        for _ in range(16):
            cdf_high = sf.where(high < 0.0, 0.0, sf.gammaincc(high + 1.0, mu_f))
            need_expand = valid & (cdf_high < q_f)
            high = sf.where(need_expand, sf.maximum(high * 2.0 + 1.0, 1.0), high)
        max_high_f = float(np.max(sf.to_numpy(sf.where(valid, high, 0.0))))
        steps = int(np.ceil(np.log2(max(max_high_f + 2.0, 2.0)))) + 2
        # Bisection keeps CDF(low) < q <= CDF(high).
        for _ in range(max(1, steps)):
            mid = sf.floor((low + high) / 2.0)
            cdf_mid = sf.where(mid < 0.0, 0.0, sf.gammaincc(mid + 1.0, mu_f))
            move_right = valid & (cdf_mid < q_f)
            low = sf.where(move_right, mid, low)
            high = sf.where(valid & (~move_right), mid, high)
        # One-step corrections on either side of the final bracket.
        k = sf.floor(high)
        cdf_k = sf.where(k < 0.0, 0.0, sf.gammaincc(k + 1.0, mu_f))
        k = sf.where(valid & (cdf_k < q_f), k + 1.0, k)
        km1 = k - 1.0
        cdf_km1 = sf.where(km1 < 0.0, 0.0, sf.gammaincc(k, mu_f))
        return sf.where(valid & (km1 >= -1.0) & (cdf_km1 >= q_f), km1, sf.where(valid, k, out))

    def pmf(self, k, mu, *, loc=0):
        """P(X = k) via log-space evaluation; 0 for negative or non-integer k."""
        sf = self._sf
        k_f = sf.as_float64(k) - float(loc)
        mu_f = float(mu)
        if mu_f < 0.0:
            return sf.where(k_f * 0 + 1, float("nan"), float("nan"))
        k_floor = sf.floor(k_f)
        is_int = (k_floor == k_f)
        valid = (k_f >= 0.0) & is_int
        k_safe = sf.maximum(k_floor, 0.0)
        logpmf = k_safe * sf.log(sf.maximum(mu_f, 1e-300)) - mu_f - sf.gammaln(k_safe + 1.0)
        return sf.where(valid, sf.exp(logpmf), 0.0)

    def cdf(self, k, mu, *, loc=0):
        """P(X <= k) = gammaincc(floor(k) + 1, mu); 0 for k < 0."""
        sf = self._sf
        k_f = sf.as_float64(k) - float(loc)
        mu_f = float(mu)
        if mu_f < 0.0:
            return sf.where(k_f * 0 + 1, float("nan"), float("nan"))
        k_floor = sf.floor(k_f)
        return sf.where(k_floor < 0.0, 0.0, sf.gammaincc(k_floor + 1.0, mu_f))

    def sf(self, k, mu, *, loc=0):
        """P(X > k) = gammainc(floor(k) + 1, mu); 1 for k < 0.

        Evaluated directly with the regularized lower incomplete gamma rather
        than ``1 - cdf``, which rounds small upper-tail probabilities to 0.
        """
        sf = self._sf
        k_f = sf.as_float64(k) - float(loc)
        mu_f = float(mu)
        if mu_f < 0.0:
            return sf.where(k_f * 0 + 1, float("nan"), float("nan"))
        k_floor = sf.floor(k_f)
        return sf.where(k_floor < 0.0, 1.0, sf.gammainc(k_floor + 1.0, mu_f))

    def ppf(self, q, mu, *, loc=0):
        """Smallest k with CDF(k) >= q, shifted by ``loc``."""
        sf = self._sf
        loc_f = float(loc)
        q_f = sf.as_float64(q)
        return self._ppf_search(q_f, mu) + loc_f

    def isf(self, q, mu, *, loc=0):
        """Inverse survival function, evaluated as ``ppf(1 - q)``."""
        return self.ppf(1.0 - self._sf.as_float64(q), mu, loc=loc)

    def rvs(self, mu, *, size=None, loc=0, dtype=None):
        """Draw samples via ``_rvs_poisson``; ``dtype`` is accepted but not forwarded."""
        return _rvs_poisson(self._sf, mu=mu, size=size, loc=loc)
|
|
1693
|
+
|
|
1694
|
+
|
|
1695
|
+
class BinomDistributionBase:
    """Binomial distribution with ``n`` trials and success probability ``p``, shifted by ``loc``.

    The CDF uses P(X <= k) = I_{1-p}(n - k, k + 1); quantiles come from a
    vectorized bisection on that CDF.  ``n < 0`` or ``p`` outside [0, 1]
    yields an all-NaN result shaped like the input.
    """

    def __init__(self, sf: "SpecialFunctions"):
        self._sf = sf

    def _ppf_search(self, q, n, p):
        """Return the smallest integer k with CDF(k) >= q (before ``loc`` shifting).

        q == 0 maps to -1 and q == 1 to n; q outside [0, 1] gives NaN.
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        n_i = int(n)
        p_f = float(p)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if n_i < 0 or p_f < 0.0 or p_f > 1.0:
            return out
        out = sf.where(q_f == 0.0, -1.0, out)
        out = sf.where(q_f == 1.0, float(n_i), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        if not bool(sf.any(valid)):
            return out
        low = sf.where(q_f * 0 + 1, -1.0, -1.0)
        high = sf.where(q_f * 0 + 1, float(n_i), float(n_i))
        steps = int(np.ceil(np.log2(max(n_i + 2, 2)))) + 2
        # Bisection keeps CDF(low) < q <= CDF(high) on [-1, n].
        for _ in range(max(1, steps)):
            mid = sf.floor((low + high) / 2.0)
            cdf_mid = self.cdf(mid, n_i, p_f, loc=0)
            move_right = valid & (cdf_mid < q_f)
            low = sf.where(move_right, mid, low)
            high = sf.where(valid & (~move_right), mid, high)
        # One-step corrections on either side of the final bracket.
        k = sf.floor(high)
        cdf_k = self.cdf(k, n_i, p_f, loc=0)
        k = sf.where(valid & (cdf_k < q_f), k + 1.0, k)
        km1 = k - 1.0
        cdf_km1 = self.cdf(km1, n_i, p_f, loc=0)
        return sf.where(valid & (km1 >= -1.0) & (cdf_km1 >= q_f), km1, sf.where(valid, k, out))

    def pmf(self, k, n, p, *, loc=0):
        """P(X = k) via log-space binomial coefficient; 0 off the integer support [0, n]."""
        sf = self._sf
        n_i = int(n)
        p_f = float(p)
        k_f = sf.as_float64(k) - float(loc)
        if n_i < 0 or p_f < 0.0 or p_f > 1.0:
            return sf.where(k_f * 0 + 1, float("nan"), float("nan"))
        k_floor = sf.floor(k_f)
        is_int = (k_floor == k_f)
        valid = (k_floor >= 0.0) & (k_floor <= float(n_i)) & is_int
        k_safe = sf.clip(k_floor, 0.0, float(n_i))
        logcoef = sf.gammaln(n_i + 1.0) - sf.gammaln(k_safe + 1.0) - sf.gammaln(n_i - k_safe + 1.0)
        logpmf = logcoef + k_safe * sf.log(sf.maximum(p_f, 1e-300)) + (n_i - k_safe) * sf.log(sf.maximum(1.0 - p_f, 1e-300))
        return sf.where(valid, sf.exp(logpmf), 0.0)

    def cdf(self, k, n, p, *, loc=0):
        """P(X <= k) = I_{1-p}(n - k, k + 1); 0 for k < 0 and 1 for k >= n."""
        sf = self._sf
        n_i = int(n)
        p_f = float(p)
        k_f = sf.as_float64(k) - float(loc)
        if n_i < 0 or p_f < 0.0 or p_f > 1.0:
            return sf.where(k_f * 0 + 1, float("nan"), float("nan"))
        k_floor = sf.floor(k_f)
        out = sf.where(k_floor < 0.0, 0.0, sf.betainc(n_i - k_floor, k_floor + 1.0, 1.0 - p_f))
        return sf.where(k_floor >= float(n_i), 1.0, out)

    def sf(self, k, n, p, *, loc=0):
        """P(X > k) = I_p(k + 1, n - k); 1 for k < 0 and 0 for k >= n.

        Evaluated directly with ``betainc`` rather than ``1 - cdf``, which
        rounds small upper-tail probabilities to 0.
        """
        sf = self._sf
        n_i = int(n)
        p_f = float(p)
        k_f = sf.as_float64(k) - float(loc)
        if n_i < 0 or p_f < 0.0 or p_f > 1.0:
            return sf.where(k_f * 0 + 1, float("nan"), float("nan"))
        k_floor = sf.floor(k_f)
        out = sf.where(k_floor < 0.0, 1.0, sf.betainc(k_floor + 1.0, n_i - k_floor, p_f))
        return sf.where(k_floor >= float(n_i), 0.0, out)

    def ppf(self, q, n, p, *, loc=0):
        """Smallest k with CDF(k) >= q, shifted by ``loc``."""
        sf = self._sf
        loc_f = float(loc)
        q_f = sf.as_float64(q)
        return self._ppf_search(q_f, n, p) + loc_f

    def isf(self, q, n, p, *, loc=0):
        """Inverse survival function, evaluated as ``ppf(1 - q)``."""
        return self.ppf(1.0 - self._sf.as_float64(q), n, p, loc=loc)

    def rvs(self, n, p, *, size=None, loc=0, dtype=None):
        """Draw samples via ``_rvs_binom``; ``dtype`` is accepted but not forwarded."""
        return _rvs_binom(self._sf, n=n, p=p, size=size, loc=loc)
|
|
1768
|
+
|
|
1769
|
+
|
|
1770
|
+
# =============================================================================
# Approximation-based inverse special functions
# =============================================================================
# Host-side helpers (NumPy arrays + scipy.special) that produce vectorized
# initial guesses for gammaincinv and betaincinv. The section also holds the
# Newton refiners and closed-form Student-t quantile approximations.
# References:
# - Wilson & Hilferty (1931) cube-root approximation for chi2/gamma
# - logit-normal approximation for beta (DiDonato & Morris 1996)
# - Newton refinement: convergence is quadratic near the root, so each step
#   roughly doubles the number of correct digits once the guess is close
|
|
1779
|
+
|
|
1780
|
+
def _gammaincinv_wilson_hilferty(a, q):
|
|
1781
|
+
"""Wilson-Hilferty cube-root approximation for gammaincinv(a, q).
|
|
1782
|
+
|
|
1783
|
+
Returns an initial guess x ≈ gammaincinv(a, q).
|
|
1784
|
+
Works for a > 0, q ∈ (0, 1). Best for a >= 1.
|
|
1785
|
+
"""
|
|
1786
|
+
import scipy.special as _scsp
|
|
1787
|
+
a_f = float(a)
|
|
1788
|
+
c = 1.0 / (9.0 * a_f)
|
|
1789
|
+
s = math.sqrt(c)
|
|
1790
|
+
z = -math.sqrt(2.0) * _scsp.erfcinv(2.0 * np.asarray(q, dtype=np.float64))
|
|
1791
|
+
x = a_f * (1.0 - c + z * s) ** 3
|
|
1792
|
+
return np.where(x > 0, x, 1e-10)
|
|
1793
|
+
|
|
1794
|
+
|
|
1795
|
+
def _gammaincinv_a_small(a, q):
    """Starting value for ``gammaincinv(a, q)`` tailored to shape ``a < 1``.

    For small x, P(a, x) ≈ x**a / Gamma(a + 1), which inverts to
    x ≈ (q * Gamma(a + 1)) ** (1/a).  That series estimate is used wherever
    it falls below 1.  Elsewhere the Wilson–Hilferty guess is used, even
    though that approximation is designed for a >= 1.
    """
    import scipy.special as _scsp

    probs = np.asarray(q, dtype=np.float64)
    shape = float(a)
    gamma_a_plus_1 = math.exp(_scsp.gammaln(shape + 1.0))
    series_guess = (probs * gamma_a_plus_1) ** (1.0 / shape)
    wh_guess = _gammaincinv_wilson_hilferty(shape, probs)
    return np.where(series_guess < 1.0, series_guess, wh_guess)
|
|
1810
|
+
|
|
1811
|
+
|
|
1812
|
+
def _gammaincinv_newton_numpy(a, q, x0, n_iter=3):
|
|
1813
|
+
"""Refine gammaincinv(a, q) using Newton's method.
|
|
1814
|
+
|
|
1815
|
+
x0: initial guess (numpy array)
|
|
1816
|
+
n_iter: number of Newton refinement steps (default 3, each gives ~1 extra digit)
|
|
1817
|
+
"""
|
|
1818
|
+
import scipy.special as scsp
|
|
1819
|
+
x = np.asarray(x0, dtype=np.float64)
|
|
1820
|
+
a_f = float(a)
|
|
1821
|
+
log_ga = math.lgamma(a_f)
|
|
1822
|
+
q_arr = np.asarray(q, dtype=np.float64)
|
|
1823
|
+
for _ in range(n_iter):
|
|
1824
|
+
p = scsp.gammainc(a_f, x)
|
|
1825
|
+
diff = p - q_arr
|
|
1826
|
+
if np.max(np.abs(diff)) < 1e-14:
|
|
1827
|
+
break
|
|
1828
|
+
log_deriv = (a_f - 1.0) * np.log(np.clip(x, 1e-300, None)) - x - log_ga
|
|
1829
|
+
deriv = np.exp(log_deriv)
|
|
1830
|
+
deriv = np.clip(deriv, 1e-300, 1e300)
|
|
1831
|
+
x = x - diff / deriv
|
|
1832
|
+
x = np.clip(x, 1e-15, 1e6)
|
|
1833
|
+
return x
|
|
1834
|
+
|
|
1835
|
+
|
|
1836
|
+
def _betaincinv_logit_approx(a, b, q):
|
|
1837
|
+
"""Logit-normal approximation for betaincinv(a, b, q).
|
|
1838
|
+
|
|
1839
|
+
For Beta(a, b), the logit transform log(X/(1-X)) is approximately normal
|
|
1840
|
+
with mean ψ(a) - ψ(b) and variance 1/a + 1/b (Digamma approximation).
|
|
1841
|
+
"""
|
|
1842
|
+
import scipy.special as _scsp
|
|
1843
|
+
a_f, b_f = float(a), float(b)
|
|
1844
|
+
mu = _scsp.digamma(a_f) - _scsp.digamma(b_f)
|
|
1845
|
+
sigma2 = 1.0 / a_f + 1.0 / b_f
|
|
1846
|
+
sigma = math.sqrt(sigma2)
|
|
1847
|
+
z = -math.sqrt(2.0) * _scsp.erfcinv(2.0 * np.asarray(q, dtype=np.float64))
|
|
1848
|
+
logit_q = mu + sigma * z
|
|
1849
|
+
x = 1.0 / (1.0 + np.exp(-logit_q))
|
|
1850
|
+
return np.clip(x, 1e-15, 1.0 - 1e-15)
|
|
1851
|
+
|
|
1852
|
+
|
|
1853
|
+
def _betaincinv_newton_numpy(a, b, q, x0, n_iter=3):
|
|
1854
|
+
"""Refine betaincinv(a, b, q) using Newton's method.
|
|
1855
|
+
|
|
1856
|
+
x0: initial guess (numpy array)
|
|
1857
|
+
n_iter: number of Newton refinement steps
|
|
1858
|
+
"""
|
|
1859
|
+
import scipy.special as scsp
|
|
1860
|
+
x = np.asarray(x0, dtype=np.float64)
|
|
1861
|
+
a_f, b_f = float(a), float(b)
|
|
1862
|
+
log_beta = math.lgamma(a_f) + math.lgamma(b_f) - math.lgamma(a_f + b_f)
|
|
1863
|
+
q_arr = np.asarray(q, dtype=np.float64)
|
|
1864
|
+
for _ in range(n_iter):
|
|
1865
|
+
p = scsp.betainc(a_f, b_f, x)
|
|
1866
|
+
diff = p - q_arr
|
|
1867
|
+
if np.max(np.abs(diff)) < 1e-14:
|
|
1868
|
+
break
|
|
1869
|
+
log_deriv = (a_f - 1.0) * np.log(np.clip(x, 1e-300, None)) + \
|
|
1870
|
+
(b_f - 1.0) * np.log(np.clip(1.0 - x, 1e-300, None)) - log_beta
|
|
1871
|
+
deriv = np.exp(log_deriv)
|
|
1872
|
+
deriv = np.clip(deriv, 1e-300, 1e300)
|
|
1873
|
+
x = x - diff / deriv
|
|
1874
|
+
x = np.clip(x, 1e-15, 1.0 - 1e-15)
|
|
1875
|
+
return x
|
|
1876
|
+
|
|
1877
|
+
|
|
1878
|
+
def _t_ppf_cornish_fisher(df, q):
|
|
1879
|
+
"""Cornish-Fisher expansion for Student-t quantile function.
|
|
1880
|
+
|
|
1881
|
+
Avoids the expensive betaincinv call.
|
|
1882
|
+
Accuracy: ~1e-10 for df >= 2, ~1e-6 for df < 2.
|
|
1883
|
+
"""
|
|
1884
|
+
import scipy.special as _scsp
|
|
1885
|
+
z = -math.sqrt(2.0) * _scsp.erfcinv(2.0 * np.asarray(q, dtype=np.float64))
|
|
1886
|
+
z2 = z * z
|
|
1887
|
+
z3 = z2 * z
|
|
1888
|
+
z5 = z3 * z2
|
|
1889
|
+
df_f = float(df)
|
|
1890
|
+
# Hall (1992) approximation for t quantile
|
|
1891
|
+
d1 = 1.0 / (4.0 * df_f)
|
|
1892
|
+
d2 = 1.0 / (96.0 * df_f * df_f)
|
|
1893
|
+
d3 = 1.0 / (384.0 * df_f * df_f * df_f)
|
|
1894
|
+
d4 = 1.0 / (9216.0 * df_f * df_f * df_f)
|
|
1895
|
+
t_approx = z + (z3 + z) * d1 + (5.0 * z5 + 16.0 * z3 + 3.0 * z) * d2 + \
|
|
1896
|
+
(3.0 * z5 + 19.0 * z3 + 17.0 * z) * d3 + \
|
|
1897
|
+
(79.0 * z5 + 462.0 * z3 + 579.0 * z) * d4
|
|
1898
|
+
return np.asarray(t_approx)
|
|
1899
|
+
|
|
1900
|
+
|
|
1901
|
+
def _t_ppf_hall_approx(df, q):
|
|
1902
|
+
"""Hall's (1992) approximation for t quantile.
|
|
1903
|
+
|
|
1904
|
+
More accurate than basic Cornish-Fisher, error ~1e-14 for df >= 1.
|
|
1905
|
+
Uses the inverse of the regularized incomplete beta via a
|
|
1906
|
+
transformed normal approximation.
|
|
1907
|
+
"""
|
|
1908
|
+
import scipy.special as scsp
|
|
1909
|
+
df_f = float(df)
|
|
1910
|
+
# Fisher-Cornish expansion
|
|
1911
|
+
z = -math.sqrt(2.0) * scsp.erfcinv(2.0 * q)
|
|
1912
|
+
z2 = z * z
|
|
1913
|
+
z3 = z2 * z
|
|
1914
|
+
z4 = z3 * z
|
|
1915
|
+
z5 = z4 * z
|
|
1916
|
+
|
|
1917
|
+
# Coefficients from Hall (1992) Biometrika
|
|
1918
|
+
a1 = 1.0 / 4.0
|
|
1919
|
+
a2 = 1.0 / 96.0
|
|
1920
|
+
a3 = -1.0 / 96.0
|
|
1921
|
+
a4 = -1.0 / 384.0
|
|
1922
|
+
|
|
1923
|
+
nu = df_f
|
|
1924
|
+
t = z + (z3 + z) * a1 / nu + \
|
|
1925
|
+
(5.0 * z5 + 16.0 * z3 + 3.0 * z) * a2 / (nu * nu) + \
|
|
1926
|
+
(3.0 * z5 + 19.0 * z3 + 17.0 * z) * a3 / (nu * nu * nu) + \
|
|
1927
|
+
(79.0 * z5 + 462.0 * z3 + 579.0 * z) * a4 / (nu * nu * nu * nu)
|
|
1928
|
+
return np.asarray(t)
|
|
1929
|
+
|
|
1930
|
+
|
|
1931
|
+
def _t_ppf_wilson_hilferty_approx(df, q):
|
|
1932
|
+
"""Wilson-Hilferty-type approximation for t PPF.
|
|
1933
|
+
|
|
1934
|
+
Uses the relationship t^2 ~ df * F(1, df) and approximates the
|
|
1935
|
+
F quantile via chi2 approximation.
|
|
1936
|
+
Best for |z| < 5 and df > 1.
|
|
1937
|
+
"""
|
|
1938
|
+
import scipy.special as scsp
|
|
1939
|
+
df_f = float(df)
|
|
1940
|
+
q2 = q # keep signed
|
|
1941
|
+
# For signed quantiles, work with |z| and restore sign
|
|
1942
|
+
sign = np.sign(q2 - 0.5)
|
|
1943
|
+
sign = np.where(sign == 0, 1.0, sign)
|
|
1944
|
+
q_abs = np.abs(q2 - 0.5) + 0.5 # always in (0.5, 1]
|
|
1945
|
+
|
|
1946
|
+
# z = Φ^{-1}(q)
|
|
1947
|
+
z = -math.sqrt(2.0) * scsp.erfcinv(2.0 * q_abs)
|
|
1948
|
+
z = z * sign
|
|
1949
|
+
|
|
1950
|
+
# Refinement: t ≈ z * (1 - 1/(4*df) + z^2/(96*df^2))^{-1/2} ...
|
|
1951
|
+
# This is a simplified version of the Hall approximation
|
|
1952
|
+
z2 = z * z
|
|
1953
|
+
t = z * (1.0 + (z2 - 1.0) / (4.0 * df_f) + (5.0 * z2 * (z2 + 7.0) - 2.0) / (96.0 * df_f * df_f))
|
|
1954
|
+
return np.asarray(t)
|
|
1955
|
+
|
|
1956
|
+
|
|
1957
|
+
# =============================================================================
# Elementwise helpers (sqrt, log, log1p, square, abs, power, floor, atan, tan,
# exp, maximum, minimum, any, to_numpy)
# =============================================================================
# Every SpecialFunctions backend needs these ops. NumPy-based versions are
# attached to the concrete classes below as fallbacks, and backend-native
# versions assigned afterwards replace them where available. _scalar_op
# provides the same NumPy fallback for objects that do not define an op.
|
|
1962
|
+
|
|
1963
|
+
def _scalar_op(sf, name, *args):
|
|
1964
|
+
"""Call a scalar operation, falling back to numpy if not on sf."""
|
|
1965
|
+
fn = getattr(sf, name, None)
|
|
1966
|
+
if fn is not None:
|
|
1967
|
+
return fn(*args)
|
|
1968
|
+
np_fn = getattr(np, name)
|
|
1969
|
+
return np_fn(*[np.asarray(a) for a in args])
|
|
1970
|
+
|
|
1971
|
+
|
|
1972
|
+
class _SpecialFunctionsMixin:
    """Mixin adding NumPy-backed scalar ops to SpecialFunctions impls.

    Every method coerces its operand(s) with ``np.asarray`` and applies the
    matching NumPy function, so results are host NumPy arrays/scalars.
    """

    # NOTE: operands go straight to NumPy. Looking the op up on ``type(self)``
    # (as ``_scalar_op(type(self), ...)`` did) finds these very methods again,
    # unbound, and calls them without ``self`` -> TypeError.

    def sqrt(self, x):
        """Element-wise square root of *x*."""
        return np.sqrt(np.asarray(x))

    def log(self, x):
        """Element-wise natural logarithm of *x*."""
        return np.log(np.asarray(x))

    def log1p(self, x):
        """Element-wise ``log(1 + x)``."""
        return np.log1p(np.asarray(x))

    def square(self, x):
        """Element-wise square of *x*."""
        return np.square(np.asarray(x))

    def abs(self, x):
        """Element-wise absolute value of *x*."""
        return np.abs(np.asarray(x))

    def power(self, x, y):
        """Element-wise ``x ** y``."""
        return np.power(np.asarray(x), np.asarray(y))

    def floor(self, x):
        """Element-wise floor of *x*."""
        return np.floor(np.asarray(x))

    def atan(self, x):
        """Element-wise arctangent of *x* (``np.arctan``)."""
        return np.arctan(np.asarray(x))

    def exp(self, x):
        """Element-wise exponential of *x*."""
        return np.exp(np.asarray(x))

    def maximum(self, x, y):
        """Element-wise maximum of *x* and *y*."""
        return np.maximum(np.asarray(x), np.asarray(y))

    def minimum(self, x, y):
        """Element-wise minimum of *x* and *y*."""
        return np.minimum(np.asarray(x), np.asarray(y))

    def any(self, x):
        """Return whether any element of *x* is truthy."""
        return np.any(np.asarray(x))

    def to_numpy(self, x):
        """Return *x* as a NumPy array."""
        return np.asarray(x)
|
|
2013
|
+
|
|
2014
|
+
|
|
2015
|
+
# Patch baseline NumPy scalar ops into each concrete implementation. The CuPy
# and Torch classes get backend-native overrides further below; these NumPy
# versions are what ScipySpecialFunctions keeps.

# method name -> NumPy function applied to ``np.asarray(x)``
_NUMPY_UNARY_SCALAR_OPS = {
    "sqrt": "sqrt",
    "log": "log",
    "log1p": "log1p",
    "square": "square",
    "abs": "abs",
    "floor": "floor",
    "atan": "arctan",
    "tan": "tan",
    "exp": "exp",
    "any": "any",
    # There is no ``np.to_numpy``; coercion itself is the operation.
    "to_numpy": "asarray",
}

# Binary ops whose method name equals the NumPy function name.
_NUMPY_BINARY_SCALAR_OPS = ("power", "maximum", "minimum")


def _numpy_unary_scalar_op(np_name):
    """Return a method computing ``numpy.<np_name>(np.asarray(x))``."""
    # Resolved eagerly so a wrong name fails at import, not at first call.
    np_fn = getattr(np, np_name)

    def method(self, x):
        return np_fn(np.asarray(x))

    return method


def _numpy_binary_scalar_op(np_name):
    """Return a method computing ``numpy.<np_name>(asarray(x), asarray(y))``."""
    np_fn = getattr(np, np_name)

    def method(self, x, y):
        return np_fn(np.asarray(x), np.asarray(y))

    return method


def _install_numpy_scalar_ops(cls):
    """Set every NumPy-backed scalar op as a method on *cls*."""
    for method_name, np_name in _NUMPY_UNARY_SCALAR_OPS.items():
        setattr(cls, method_name, _numpy_unary_scalar_op(np_name))
    for method_name in _NUMPY_BINARY_SCALAR_OPS:
        setattr(cls, method_name, _numpy_binary_scalar_op(method_name))


for _cls in (CuPySpecialFunctions, TorchSpecialFunctions, ScipySpecialFunctions):
    _install_numpy_scalar_ops(_cls)
del _cls
|
|
2031
|
+
|
|
2032
|
+
# Now override with backend-native versions.
# ScipySpecialFunctions already works on NumPy data, so ``to_numpy`` only has
# to coerce its input (there is no ``np.to_numpy`` function).
def _scipy_to_numpy(self, x):
    """Return *x* as a NumPy array via ``np.asarray``."""
    return np.asarray(x)


ScipySpecialFunctions.to_numpy = _scipy_to_numpy
|
|
2035
|
+
|
|
2036
|
+
# Backend-native CuPy scalar ops. Operands are converted with
# ``cupy.asarray(..., dtype=float64)`` before the matching ``cupy`` function is
# applied, so host inputs (NumPy arrays, Python scalars) are accepted too.


def _cupy_float64_unary(op_name):
    """Return a method applying ``cupy.<op_name>`` to *x* cast to float64."""

    def method(self, x):
        cp = self._cp
        return getattr(cp, op_name)(cp.asarray(x, dtype=cp.float64))

    return method


def _cupy_float64_binary(op_name):
    """Return a method applying ``cupy.<op_name>`` to *x*, *y* cast to float64."""

    def method(self, x, y):
        cp = self._cp
        return getattr(cp, op_name)(
            cp.asarray(x, dtype=cp.float64),
            cp.asarray(y, dtype=cp.float64),
        )

    return method


def _cupy_any(self, x):
    """Return ``cupy.any(x)``, moving host inputs to the device first.

    ``cupy.any`` only accepts ``cupy.ndarray`` (CuPy does no implicit host ->
    device conversion), so *x* goes through ``cupy.asarray``. That call does
    not copy existing device arrays. No dtype cast, so booleans stay booleans.
    """
    return self._cp.any(self._cp.asarray(x))


def _cupy_to_numpy(self, x):
    """Copy CuPy arrays to host; coerce anything else with ``np.asarray``."""
    return self._cp.asnumpy(x) if hasattr(x, 'get') else np.asarray(x)


CuPySpecialFunctions.sqrt = _cupy_float64_unary("sqrt")
CuPySpecialFunctions.log = _cupy_float64_unary("log")
CuPySpecialFunctions.log1p = _cupy_float64_unary("log1p")
CuPySpecialFunctions.square = _cupy_float64_unary("square")
CuPySpecialFunctions.abs = _cupy_float64_unary("abs")
CuPySpecialFunctions.power = _cupy_float64_binary("power")
CuPySpecialFunctions.floor = _cupy_float64_unary("floor")
CuPySpecialFunctions.atan = _cupy_float64_unary("arctan")
CuPySpecialFunctions.tan = _cupy_float64_unary("tan")
CuPySpecialFunctions.exp = _cupy_float64_unary("exp")
CuPySpecialFunctions.maximum = _cupy_float64_binary("maximum")
CuPySpecialFunctions.minimum = _cupy_float64_binary("minimum")
CuPySpecialFunctions.any = _cupy_any
CuPySpecialFunctions.to_numpy = _cupy_to_numpy
|
|
2050
|
+
|
|
2051
|
+
# Backend-native torch scalar ops. Operands go through ``self._as_tensor``
# (the class's own conversion helper) before the matching ``torch`` function
# is applied.


def _torch_unary(op_name):
    """Return a method applying ``torch.<op_name>`` to ``self._as_tensor(x)``."""

    def method(self, x):
        return getattr(self._torch, op_name)(self._as_tensor(x))

    return method


def _torch_binary(op_name):
    """Return a method applying ``torch.<op_name>`` to two converted operands."""

    def method(self, x, y):
        return getattr(self._torch, op_name)(self._as_tensor(x), self._as_tensor(y))

    return method


def _torch_any(self, x):
    """Return ``torch.any(x)``, converting non-tensor inputs first.

    ``torch.any`` rejects non-tensors, so NumPy arrays / Python values go
    through ``self._as_tensor`` like every other op here. Existing tensors are
    passed through untouched to keep their dtype and device.
    """
    torch = self._torch
    if not torch.is_tensor(x):
        x = self._as_tensor(x)
    return torch.any(x)


def _torch_to_numpy(self, x):
    """Detach and copy tensors to host NumPy; coerce anything else."""
    return x.detach().cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)


TorchSpecialFunctions.sqrt = _torch_unary("sqrt")
TorchSpecialFunctions.log = _torch_unary("log")
TorchSpecialFunctions.log1p = _torch_unary("log1p")
TorchSpecialFunctions.square = _torch_unary("square")
TorchSpecialFunctions.abs = _torch_unary("abs")
TorchSpecialFunctions.power = _torch_binary("pow")
TorchSpecialFunctions.floor = _torch_unary("floor")
TorchSpecialFunctions.atan = _torch_unary("atan")
TorchSpecialFunctions.tan = _torch_unary("tan")
TorchSpecialFunctions.exp = _torch_unary("exp")
TorchSpecialFunctions.maximum = _torch_binary("maximum")
TorchSpecialFunctions.minimum = _torch_binary("minimum")
TorchSpecialFunctions.any = _torch_any
TorchSpecialFunctions.to_numpy = _torch_to_numpy
|
|
2065
|
+
|
|
2066
|
+
|
|
2067
|
+
# =============================================================================
|
|
2068
|
+
# Safe scalar operations (clip 1-x to [0,1] for SF computation)
|
|
2069
|
+
# =============================================================================
|
|
2070
|
+
|
|
2071
|
+
def sf_safe_sub(val, other, sf):
    """Return ``val - other`` clipped to the probability range [0, 1].

    Clipping is delegated to ``sf.clip`` so the result stays on sf's backend.
    """
    difference = val - other
    lower, upper = 0.0, 1.0
    return sf.clip(difference, lower, upper)
|
|
2074
|
+
|
|
2075
|
+
|
|
2076
|
+
def sf_safe_mul(val, factor, sf):
    """Return ``val * factor`` clipped to the probability range [0, 1].

    Clipping is delegated to ``sf.clip`` so the result stays on sf's backend.
    """
    product = val * factor
    lower, upper = 0.0, 1.0
    return sf.clip(product, lower, upper)
|
|
2079
|
+
|
|
2080
|
+
|
|
2081
|
+
# =============================================================================
|
|
2082
|
+
# Random variate helpers (pure-numpy CPU generation, then converted)
|
|
2083
|
+
# =============================================================================
|
|
2084
|
+
|
|
2085
|
+
def _rvs_normal(sf, *, size, loc, scale):
    """Draw normal variates from NumPy's global RNG, then convert via *sf*.

    The NumPy result is passed through ``sf.as_float64`` when *sf* has that
    attribute; otherwise it is returned unchanged.
    """
    samples = np.random.normal(loc=float(loc), scale=float(scale), size=size)
    return sf.as_float64(samples) if hasattr(sf, "as_float64") else samples
|
|
2090
|
+
|
|
2091
|
+
|
|
2092
|
+
def _rvs_t(sf, *, df, size, loc, scale):
    """Draw Student-t variates with ``scipy.stats.t.rvs`` on the CPU.

    No ``random_state`` is passed, so SciPy uses NumPy's global RNG. The
    result goes through ``sf.as_float64`` when *sf* provides it.
    """
    # SciPy fallback: there is no pure-NumPy t sampler used here.
    import scipy.stats as sps

    samples = sps.t.rvs(df=float(df), size=size, loc=float(loc), scale=float(scale))
    if not hasattr(sf, "as_float64"):
        return samples
    return sf.as_float64(samples)
|
|
2099
|
+
|
|
2100
|
+
|
|
2101
|
+
def _rvs_uniform(sf, *, size, loc, scale):
    """Draw uniform variates on ``[loc, loc + scale)`` from NumPy's global RNG.

    The result goes through ``sf.as_float64`` when *sf* provides it.
    """
    low = float(loc)
    high = low + float(scale)
    samples = np.random.uniform(low=low, high=high, size=size)
    if not hasattr(sf, "as_float64"):
        return samples
    return sf.as_float64(samples)
|
|
2106
|
+
|
|
2107
|
+
|
|
2108
|
+
def _rvs_expon(sf, *, size, loc, scale):
    """Draw exponential variates shifted by *loc* from NumPy's global RNG.

    The result goes through ``sf.as_float64`` when *sf* provides it.
    """
    shift = float(loc)
    samples = shift + np.random.exponential(scale=float(scale), size=size)
    return sf.as_float64(samples) if hasattr(sf, "as_float64") else samples
|
|
2113
|
+
|
|
2114
|
+
|
|
2115
|
+
def _rvs_cauchy(sf, *, size, loc, scale):
    """Draw Cauchy variates by inverse-CDF sampling on NumPy's global RNG.

    Computes ``loc + scale * tan(pi * (U - 0.5))`` with ``U`` from
    ``np.random.random``. The result goes through ``sf.as_float64`` when *sf*
    provides it.
    """
    uniforms = np.random.random(size=size)
    angles = np.pi * (uniforms - 0.5)
    draws = float(loc) + float(scale) * np.tan(angles)
    return sf.as_float64(draws) if hasattr(sf, "as_float64") else draws
|
|
2121
|
+
|
|
2122
|
+
|
|
2123
|
+
def _rvs_laplace(sf, *, size, loc, scale):
    """Draw Laplace variates from NumPy's global RNG.

    The result goes through ``sf.as_float64`` when *sf* provides it.
    """
    samples = np.random.laplace(loc=float(loc), scale=float(scale), size=size)
    if not hasattr(sf, "as_float64"):
        return samples
    return sf.as_float64(samples)
|
|
2128
|
+
|
|
2129
|
+
|
|
2130
|
+
def _rvs_logistic(sf, *, size, loc, scale):
    """Draw logistic variates from NumPy's global RNG.

    Uses ``np.random.logistic`` instead of hand-rolled inverse-CDF sampling.
    ``np.random.random`` draws from ``[0, 1)``, so ``log(u / (1 - u))`` could
    hit ``log(0) = -inf`` (with a divide-by-zero warning) when ``u == 0.0``.
    The result goes through ``sf.as_float64`` when *sf* provides it.
    """
    out = np.random.logistic(loc=float(loc), scale=float(scale), size=size)
    if hasattr(sf, 'as_float64'):
        return sf.as_float64(out)
    return out
|
|
2136
|
+
|
|
2137
|
+
|
|
2138
|
+
def _rvs_chi2(sf, *, df, size):
    """Draw chi-squared variates with *df* degrees of freedom (NumPy RNG).

    The result goes through ``sf.as_float64`` when *sf* provides it.
    """
    samples = np.random.chisquare(df=float(df), size=size)
    return sf.as_float64(samples) if hasattr(sf, "as_float64") else samples
|
|
2143
|
+
|
|
2144
|
+
|
|
2145
|
+
def _rvs_gamma(sf, *, a, size, loc, scale):
    """Draw gamma variates (shape *a*) shifted by *loc* from NumPy's RNG.

    The result goes through ``sf.as_float64`` when *sf* provides it.
    """
    shift = float(loc)
    samples = shift + np.random.gamma(shape=float(a), scale=float(scale), size=size)
    if not hasattr(sf, "as_float64"):
        return samples
    return sf.as_float64(samples)
|
|
2150
|
+
|
|
2151
|
+
|
|
2152
|
+
def _rvs_beta(sf, *, a, b, size, loc, scale):
    """Draw beta(a, b) variates mapped to ``loc + scale * X`` (NumPy RNG).

    The result goes through ``sf.as_float64`` when *sf* provides it.
    """
    unit_draws = np.random.beta(float(a), float(b), size=size)
    samples = float(loc) + float(scale) * unit_draws
    return sf.as_float64(samples) if hasattr(sf, "as_float64") else samples
|
|
2157
|
+
|
|
2158
|
+
|
|
2159
|
+
def _rvs_f(sf, *, dfn, dfd, size):
|
|
2160
|
+
out = np.random.f(dfn=float(dfn), dfd=float(dfd), size=size)
|
|
2161
|
+
if hasattr(sf, 'as_float64'):
|
|
2162
|
+
return sf.as_float64(out)
|
|
2163
|
+
return out
|
|
2164
|
+
|
|
2165
|
+
|
|
2166
|
+
def _rvs_weibull(sf, *, c, size, loc, scale):
    """Draw Weibull variates (shape *c*) mapped to ``loc + scale * X``.

    Uses NumPy's global RNG. The result goes through ``sf.as_float64`` when
    *sf* provides it.
    """
    standard = np.random.weibull(a=float(c), size=size)
    samples = float(loc) + float(scale) * standard
    return sf.as_float64(samples) if hasattr(sf, "as_float64") else samples
|
|
2171
|
+
|
|
2172
|
+
|
|
2173
|
+
def _rvs_lognorm(sf, *, s, size, loc, scale):
    """Draw log-normal variates ``loc + scale * exp(s * Z)`` with Z ~ N(0, 1).

    Uses NumPy's global RNG. The result goes through ``sf.as_float64`` when
    *sf* provides it.
    """
    z = np.random.normal(size=size)
    samples = float(loc) + float(scale) * np.exp(float(s) * z)
    if not hasattr(sf, "as_float64"):
        return samples
    return sf.as_float64(samples)
|
|
2178
|
+
|
|
2179
|
+
|
|
2180
|
+
def _rvs_poisson(sf, *, mu, size, loc):
    """Draw Poisson(mu) counts shifted by ``int(loc)`` from NumPy's RNG.

    The result goes through ``sf.as_float64`` when *sf* provides it.
    """
    counts = np.random.poisson(lam=float(mu), size=size)
    shifted = counts + int(loc)
    return sf.as_float64(shifted) if hasattr(sf, "as_float64") else shifted
|
|
2185
|
+
|
|
2186
|
+
|
|
2187
|
+
def _rvs_binom(sf, *, n, p, size, loc):
    """Draw Binomial(n, p) counts shifted by ``int(loc)`` from NumPy's RNG.

    The result goes through ``sf.as_float64`` when *sf* provides it.
    """
    counts = np.random.binomial(n=int(n), p=float(p), size=size)
    shifted = counts + int(loc)
    if not hasattr(sf, "as_float64"):
        return shifted
    return sf.as_float64(shifted)
|
|
2192
|
+
|
|
2193
|
+
|
|
2194
|
+
# =============================================================================
|
|
2195
|
+
# Factory
|
|
2196
|
+
# =============================================================================
|
|
2197
|
+
|
|
2198
|
+
# Registry: distribution name -> callable(sf) returning the distribution object.
# Most classes act as their own factory; lognorm additionally needs a normal
# distribution bound to the same SpecialFunctions instance.
_DISTRIBUTION_FACTORIES = {
    # continuous
    "norm": NormDistributionBase,
    "t": TDistributionBase,
    "uniform": UniformDistributionBase,
    "expon": ExponDistributionBase,
    "cauchy": CauchyDistributionBase,
    "laplace": LaplaceDistributionBase,
    "logistic": LogisticDistributionBase,
    "chi2": Chi2DistributionBase,
    "gamma": GammaDistributionBase,
    "beta": BetaDistributionBase,
    "f": FDistributionBase,
    "weibull_min": WeibullMinDistributionBase,
    "lognorm": lambda sf: LognormDistributionBase(sf, NormDistributionBase(sf)),
    # discrete
    "poisson": PoissonDistributionBase,
    "binom": BinomDistributionBase,
}

# Alphabetically sorted names of every natively implemented distribution.
_NATIVE_NAMES = sorted(_DISTRIBUTION_FACTORIES)
|
|
2217
|
+
|
|
2218
|
+
|
|
2219
|
+
def _make_sf(backend: str, device: str | None = None, *, use_lut: bool = True) -> SpecialFunctions:
    """Instantiate the SpecialFunctions implementation for *backend*.

    ``device`` is only forwarded to the torch implementation; ``use_lut`` is
    forwarded to all of them.

    Raises
    ------
    ValueError
        If *backend* is not ``'numpy'``, ``'cupy'`` or ``'torch'``.
    """
    if backend == "torch":
        return TorchSpecialFunctions(device=device, use_lut=use_lut)
    if backend in ("numpy", "cupy"):
        impl = ScipySpecialFunctions if backend == "numpy" else CuPySpecialFunctions
        return impl(use_lut=use_lut)
    raise ValueError(f"Unsupported backend: {backend}")
|
|
2228
|
+
|
|
2229
|
+
|
|
2230
|
+
def _resolve_factory(name):
    """Return the registered factory for distribution *name*.

    Tries an exact match first, then a case- and whitespace-insensitive one.

    Raises
    ------
    ValueError
        If no native implementation is registered under *name*.
    """
    factory = _DISTRIBUTION_FACTORIES.get(name)
    if factory is None:
        factory = _DISTRIBUTION_FACTORIES.get(str(name).strip().lower())
    if factory is None:
        raise ValueError(f"Unknown distribution: {name}")
    return factory


def get_distribution(name: str, backend: str = "auto", device: str | None = None, *, use_lut: bool = True):
    """Get a distribution object for the given backend.

    Parameters
    ----------
    name : str
        Distribution name (e.g. ``'norm'``, ``'t'``, ``'chi2'``). Matched
        exactly first, then case-insensitively.
    backend : {'auto', 'numpy', 'cupy', 'torch'}, default='auto'
        Which backend to use. ``'auto'`` picks the first available GPU
        backend (cupy > torch) or falls back to numpy.
    device : str, optional
        Torch device string (e.g. ``'cuda'``, ``'cuda:0'``, ``'cpu'``).
        Only used when backend is ``'torch'``.
    use_lut : bool, default=True
        Use LUT cache + 1-step Newton refinement for inverse special functions
        (``betaincinv``, ``gammaincinv``). When ``False``, falls back to the
        full iterative solver (scipy for numpy, Newton-Raphson for torch).
        ``True`` gives 10-500x speedup for ``t.ppf``/``f.ppf`` on GPU,
        with negligible accuracy loss (LUT is built from scipy reference values).

    Returns
    -------
    Distribution object with methods: cdf, sf, ppf, isf, pdf, rvs, etc.

    Raises
    ------
    ValueError
        If *name* is not a native distribution, or *backend* is unsupported.
    """
    # Resolve the name first: an unknown name used to build (and silently
    # discard) CuPy and Torch backends before failing on numpy.
    factory = _resolve_factory(name)

    if backend == "auto":
        for gpu_backend in ("cupy", "torch"):
            try:
                return factory(_make_sf(gpu_backend, device, use_lut=use_lut))
            except Exception:
                # Best effort: package missing, no device, or the backend
                # cannot build this distribution -> try the next candidate.
                continue
        backend = "numpy"

    return factory(_make_sf(backend, device, use_lut=use_lut))
|
|
2274
|
+
|
|
2275
|
+
|
|
2276
|
+
def list_available_distributions():
    """Return a new, alphabetically sorted list of native distribution names.

    A copy is returned so callers can mutate it freely.
    """
    return [*_NATIVE_NAMES]
|
|
2279
|
+
|
|
2280
|
+
|
|
2281
|
+
# =============================================================================
|
|
2282
|
+
# DistributionProxy — module-level lazy singletons
|
|
2283
|
+
# =============================================================================
|
|
2284
|
+
|
|
2285
|
+
class DistributionProxy:
|
|
2286
|
+
"""Lazy proxy that resolves the backend on each call.
|
|
2287
|
+
|
|
2288
|
+
Supports ``backend=`` keyword override::
|
|
2289
|
+
|
|
2290
|
+
norm.cdf(x) # auto backend
|
|
2291
|
+
norm.cdf(x, backend="torch") # force torch
|
|
2292
|
+
"""
|
|
2293
|
+
|
|
2294
|
+
def __init__(self, name: str, default_backend: str = "auto", device: str | None = None, *, use_lut: bool = True):
|
|
2295
|
+
self._name = name
|
|
2296
|
+
self._default_backend = default_backend
|
|
2297
|
+
self._device = device
|
|
2298
|
+
self._use_lut = use_lut
|
|
2299
|
+
|
|
2300
|
+
def _resolve(self, kwargs, *arrays):
|
|
2301
|
+
from statgpu.backends import _is_torch_array, _resolve_backend
|
|
2302
|
+
|
|
2303
|
+
backend = kwargs.pop("backend", self._default_backend)
|
|
2304
|
+
device = kwargs.pop("device", self._device)
|
|
2305
|
+
use_lut = kwargs.pop("use_lut", self._use_lut)
|
|
2306
|
+
if backend == "auto":
|
|
2307
|
+
backend = _resolve_backend("auto", *arrays, *kwargs.values())
|
|
2308
|
+
if backend == "torch" and device is None:
|
|
2309
|
+
for arr in (*arrays, *kwargs.values()):
|
|
2310
|
+
if _is_torch_array(arr):
|
|
2311
|
+
device = str(arr.device)
|
|
2312
|
+
break
|
|
2313
|
+
return get_distribution(self._name, backend=backend, device=device, use_lut=use_lut)
|
|
2314
|
+
|
|
2315
|
+
def __repr__(self):
|
|
2316
|
+
return (f"DistributionProxy({self._name!r}, "
|
|
2317
|
+
f"backend={self._default_backend!r}, "
|
|
2318
|
+
f"use_lut={self._use_lut!r})")
|
|
2319
|
+
|
|
2320
|
+
def cdf(self, x, **kw):
|
|
2321
|
+
return self._resolve(kw, x).cdf(x, **kw)
|
|
2322
|
+
|
|
2323
|
+
def sf(self, x, **kw):
|
|
2324
|
+
return self._resolve(kw, x).sf(x, **kw)
|
|
2325
|
+
|
|
2326
|
+
def ppf(self, q, **kw):
|
|
2327
|
+
return self._resolve(kw, q).ppf(q, **kw)
|
|
2328
|
+
|
|
2329
|
+
def isf(self, q, **kw):
|
|
2330
|
+
return self._resolve(kw, q).isf(q, **kw)
|
|
2331
|
+
|
|
2332
|
+
def pdf(self, x, **kw):
|
|
2333
|
+
return self._resolve(kw, x).pdf(x, **kw)
|
|
2334
|
+
|
|
2335
|
+
def pmf(self, k, **kw):
|
|
2336
|
+
return self._resolve(kw, k).pmf(k, **kw)
|
|
2337
|
+
|
|
2338
|
+
def rvs(self, **kw):
|
|
2339
|
+
return self._resolve(kw, *kw.values()).rvs(**kw)
|
|
2340
|
+
|
|
2341
|
+
def two_sided_pvalue(self, stat_abs, **kw):
|
|
2342
|
+
return self._resolve(kw, stat_abs).two_sided_pvalue(stat_abs, **kw)
|
|
2343
|
+
|
|
2344
|
+
def two_sided_critical_value(self, alpha, **kw):
|
|
2345
|
+
return self._resolve(kw, alpha).two_sided_critical_value(alpha, **kw)
|
|
2346
|
+
|
|
2347
|
+
|
|
2348
|
+
# Module-level singletons: lazy proxies whose backend is resolved per call.
# -- continuous --
norm = DistributionProxy(name="norm")
t = DistributionProxy(name="t")
uniform = DistributionProxy(name="uniform")
expon = DistributionProxy(name="expon")
cauchy = DistributionProxy(name="cauchy")
laplace = DistributionProxy(name="laplace")
logistic = DistributionProxy(name="logistic")
chi2 = DistributionProxy(name="chi2")
gamma = DistributionProxy(name="gamma")
beta = DistributionProxy(name="beta")
f = DistributionProxy(name="f")
weibull_min = DistributionProxy(name="weibull_min")
lognorm = DistributionProxy(name="lognorm")
# -- discrete --
poisson = DistributionProxy(name="poisson")
binom = DistributionProxy(name="binom")
|
|
2364
|
+
|
|
2365
|
+
|
|
2366
|
+
# Backward-compatible aliases: the old CuPy-specific ``*DistributionGPU`` class
# names now refer to the backend-agnostic ``*DistributionBase`` classes.

# -- continuous --
NormDistributionGPU = NormDistributionBase
TDistributionGPU = TDistributionBase
UniformDistributionGPU = UniformDistributionBase
ExponDistributionGPU = ExponDistributionBase
CauchyDistributionGPU = CauchyDistributionBase
LaplaceDistributionGPU = LaplaceDistributionBase
LogisticDistributionGPU = LogisticDistributionBase
Chi2DistributionGPU = Chi2DistributionBase
GammaDistributionGPU = GammaDistributionBase
BetaDistributionGPU = BetaDistributionBase
FDistributionGPU = FDistributionBase
WeibullMinDistributionGPU = WeibullMinDistributionBase
LognormDistributionGPU = LognormDistributionBase

# -- discrete --
PoissonDistributionGPU = PoissonDistributionBase
BinomDistributionGPU = BinomDistributionBase
|
|
2382
|
+
|
|
2383
|
+
|
|
2384
|
+
def _scipy_distribution_name(key):
    """Return the ``scipy.stats`` attribute name of distribution *key*, or None.

    The lower-cased name is tried first, then *key* as given. Only
    ``rv_continuous`` / ``rv_discrete`` instances count, so other public
    ``scipy.stats`` callables (``ttest_ind``, ``describe``, ...) are rejected.
    """
    import scipy.stats as sps

    for candidate in (key.lower(), key):
        obj = getattr(sps, candidate, None)
        if isinstance(obj, (sps.rv_continuous, sps.rv_discrete)):
            return candidate
    return None


def get_distribution_gpu(name: str, *, allow_fallback: bool = False):
    """Backward-compatible wrapper: get a distribution by name.

    Native names are built by ``get_distribution(..., backend="auto")``
    (CuPy, then torch, then NumPy). Other ``scipy.stats`` distributions are
    wrapped in :class:`ScipyFallbackDistribution` only when
    ``allow_fallback=True``.

    Raises
    ------
    ValueError
        If *name* is a SciPy-only distribution and ``allow_fallback`` is
        False, or if *name* is not a distribution at all.
    """
    key = str(name).strip()
    if key.lower() in _DISTRIBUTION_FACTORIES:
        return get_distribution(key.lower(), backend="auto")

    scipy_name = _scipy_distribution_name(key)
    if scipy_name is None:
        raise ValueError(f"Unknown scipy.stats distribution: {name}")
    if allow_fallback:
        return ScipyFallbackDistribution(scipy_name)
    raise ValueError(
        f"Distribution '{name}' has no native GPU implementation. "
        "Set allow_fallback=True for explicit SciPy fallback."
    )
|
|
2405
|
+
|
|
2406
|
+
|
|
2407
|
+
def list_available_distributions_gpu(include_scipy: bool = True):
    """Backward-compatible: list available distribution names.

    Returns the native names; with ``include_scipy`` also every public
    ``scipy.stats`` attribute that is an ``rv_continuous``/``rv_discrete``
    instance. The combined result is de-duplicated and sorted.
    """
    native = list_available_distributions()
    if not include_scipy:
        return native

    import scipy.stats as sps
    from scipy.stats import rv_continuous, rv_discrete

    def _is_scipy_distribution(attr_name):
        try:
            candidate = getattr(sps, attr_name)
        except Exception:
            return False
        return isinstance(candidate, (rv_continuous, rv_discrete))

    scipy_names = {
        attr_name
        for attr_name in dir(sps)
        if not attr_name.startswith("_") and _is_scipy_distribution(attr_name)
    }
    return sorted(scipy_names.union(native))
|
|
2427
|
+
|
|
2428
|
+
|
|
2429
|
+
class ScipyFallbackDistribution:
    """Dynamic scipy.stats distribution wrapper returning GPU-backed outputs.

    GPU inputs (CuPy arrays, torch tensors) are copied to host NumPy before
    SciPy is called. The result is converted with the default statgpu
    backend's ``asarray`` when that backend is not ``"numpy"``. If the backend
    cannot be determined, the SciPy result is returned unchanged.
    """

    def __init__(self, name: str):
        self.name = str(name)

    def __repr__(self):
        return f"ScipyFallbackDistribution('{self.name}')"

    @staticmethod
    def _to_host(value):
        """Return *value* in a form SciPy can consume.

        torch tensors are detached and copied to CPU NumPy, and CuPy arrays
        are copied with ``.get()``. CuPy is recognised by its defining module
        rather than by ``hasattr(value, "get")``: host objects such as
        ``dict`` or ``pandas.Series`` also have a ``get`` method, and calling
        it with no key fails.
        """
        if hasattr(value, "detach"):
            return value.detach().cpu().numpy()
        if type(value).__module__.split(".", 1)[0] in ("cupy", "cupyx"):
            return value.get()
        return value

    def _call(self, method_name, *args, **kwargs):
        """Run ``scipy.stats.<name>.<method_name>`` on host copies of inputs."""
        import scipy.stats as sps

        method = getattr(getattr(sps, self.name), method_name)
        host_args = [self._to_host(v) for v in args]
        host_kwargs = {k: self._to_host(v) for k, v in kwargs.items()}
        result = method(*host_args, **host_kwargs)
        # Best effort: hand the result to the default GPU backend, if any.
        try:
            from statgpu.backends import get_backend

            backend = get_backend()
            if backend.name != "numpy":
                return backend.asarray(result)
        except Exception:
            pass
        return result

    def cdf(self, x, *shape_args, **kwargs):
        return self._call("cdf", x, *shape_args, **kwargs)

    def sf(self, x, *shape_args, **kwargs):
        return self._call("sf", x, *shape_args, **kwargs)

    def ppf(self, q, *shape_args, **kwargs):
        return self._call("ppf", q, *shape_args, **kwargs)

    def isf(self, q, *shape_args, **kwargs):
        return self._call("isf", q, *shape_args, **kwargs)

    def pdf(self, x, *shape_args, **kwargs):
        return self._call("pdf", x, *shape_args, **kwargs)

    def pmf(self, x, *shape_args, **kwargs):
        return self._call("pmf", x, *shape_args, **kwargs)

    def rvs(self, *shape_args, size=None, dtype=None, **kwargs):
        """Draw variates; cast with ``astype(dtype)`` when the result supports it."""
        out = self._call("rvs", *shape_args, size=size, **kwargs)
        if dtype is not None and hasattr(out, "astype"):
            out = out.astype(dtype)
        return out
|
|
2493
|
+
|
|
2494
|
+
|
|
2495
|
+
# =============================================================================
|
|
2496
|
+
# Backward-compatible special function aliases (for old consumers)
|
|
2497
|
+
# =============================================================================
|
|
2498
|
+
|
|
2499
|
+
def regularized_betainc_gpu(a, b, x):
    """Backward-compatible alias: use get_distribution for new code.

    Builds a fresh ``CuPySpecialFunctions`` and returns its ``betainc(a, b, x)``.
    """
    return CuPySpecialFunctions().betainc(a, b, x)
|
|
2503
|
+
|
|
2504
|
+
|
|
2505
|
+
def regularized_betaincinv_gpu(a, b, y):
    """Backward-compatible alias.

    Builds a fresh ``CuPySpecialFunctions`` and returns its
    ``betaincinv(a, b, y)``.
    """
    return CuPySpecialFunctions().betaincinv(a, b, y)
|
|
2509
|
+
|
|
2510
|
+
|
|
2511
|
+
def gammainc_gpu(a, x):
    """Backward-compatible alias.

    Builds a fresh ``CuPySpecialFunctions`` and returns its ``gammainc(a, x)``.
    """
    return CuPySpecialFunctions().gammainc(a, x)
|
|
2515
|
+
|
|
2516
|
+
|
|
2517
|
+
def gammaincc_gpu(a, x):
    """Backward-compatible alias.

    Builds a fresh ``CuPySpecialFunctions`` and returns its ``gammaincc(a, x)``.
    """
    return CuPySpecialFunctions().gammaincc(a, x)
|
|
2521
|
+
|
|
2522
|
+
|
|
2523
|
+
def gammaincinv_gpu(a, q):
    """Backward-compatible alias.

    Builds a fresh ``CuPySpecialFunctions`` and returns its
    ``gammaincinv(a, q)``.
    """
    return CuPySpecialFunctions().gammaincinv(a, q)
|
|
2527
|
+
|
|
2528
|
+
|
|
2529
|
+
def gammaln_gpu(x):
    """Backward-compatible alias.

    Builds a fresh ``CuPySpecialFunctions`` and returns its ``gammaln(x)``.
    """
    return CuPySpecialFunctions().gammaln(x)
|
|
2533
|
+
|
|
2534
|
+
|
|
2535
|
+
# =============================================================================
|
|
2536
|
+
# Legacy distribution-function names (R-style)
|
|
2537
|
+
# =============================================================================
|
|
2538
|
+
|
|
2539
|
+
# Legacy function names that the module ``__getattr__`` imports on demand from
# ``statgpu.linear_model.legacy._distributions_legacy_gpu``.
_LEGACY_DISTRIBUTION_FUNCTION_NAMES = {
    # Student t
    "t_cdf_gpu", "t_sf_gpu", "t_ppf_gpu",
    "t_two_sided_pvalue_gpu", "t_two_sided_critical_value_gpu",
    "dt_gpu", "pt_gpu", "qt_gpu", "rt_gpu",
    # normal
    "norm_cdf_gpu", "norm_sf_gpu", "norm_ppf_gpu", "norm_isf_gpu",
    "norm_two_sided_pvalue_gpu", "norm_two_sided_critical_value_gpu",
    "dnorm_gpu", "pnorm_gpu", "qnorm_gpu", "rnorm_gpu",
    # chi-squared
    "dchisq_gpu", "pchisq_gpu", "qchisq_gpu", "rchisq_gpu",
    # gamma
    "dgamma_gpu", "pgamma_gpu", "qgamma_gpu", "rgamma_gpu",
    # beta
    "dbeta_gpu", "pbeta_gpu", "qbeta_gpu", "rbeta_gpu",
    # F
    "df_gpu", "pf_gpu", "qf_gpu", "rf_gpu",
    # Poisson
    "dpois_gpu", "ppois_gpu", "qpois_gpu", "rpois_gpu",
    # binomial
    "dbinom_gpu", "pbinom_gpu", "qbinom_gpu", "rbinom_gpu",
}
|
|
2552
|
+
|
|
2553
|
+
|
|
2554
|
+
def __getattr__(name):
    """Resolve missing module attributes lazily (PEP 562 module hook).

    1. Private names (leading underscore) raise ``AttributeError``.
    2. Legacy helpers listed in ``_LEGACY_DISTRIBUTION_FUNCTION_NAMES`` are
       imported on demand from the legacy module.
    3. Anything else goes to ``get_distribution_gpu(name)``; any failure there
       is re-raised as ``AttributeError`` chained to the original error.
    """
    missing = f"module {__name__} has no attribute {name}"
    if name.startswith("_"):
        raise AttributeError(missing)
    if name not in _LEGACY_DISTRIBUTION_FUNCTION_NAMES:
        try:
            return get_distribution_gpu(name)
        except Exception as exc:
            raise AttributeError(missing) from exc
    from statgpu.linear_model.legacy import _distributions_legacy_gpu as legacy
    return getattr(legacy, name)
|
|
2565
|
+
|
|
2566
|
+
|
|
2567
|
+
# =============================================================================
|
|
2568
|
+
# Exports
|
|
2569
|
+
# =============================================================================
|
|
2570
|
+
|
|
2571
|
+
__all__ = [
    # Core
    "get_distribution",
    "list_available_distributions",
    "DistributionProxy",
    "SpecialFunctions",
    # Backends
    "CuPySpecialFunctions",
    "TorchSpecialFunctions",
    "ScipySpecialFunctions",
    # Distributions
    "NormDistributionBase",
    "TDistributionBase",
    "UniformDistributionBase",
    "ExponDistributionBase",
    "CauchyDistributionBase",
    "LaplaceDistributionBase",
    "LogisticDistributionBase",
    "Chi2DistributionBase",
    "GammaDistributionBase",
    "BetaDistributionBase",
    "FDistributionBase",
    "WeibullMinDistributionBase",
    "LognormDistributionBase",
    "PoissonDistributionBase",
    "BinomDistributionBase",
    # Module-level proxies
    "norm",
    "t",
    "uniform",
    "expon",
    "cauchy",
    "laplace",
    "logistic",
    "chi2",
    "gamma",
    "beta",
    "f",
    "weibull_min",
    "lognorm",
    "poisson",
    "binom",
    # Backward compat
    "NormDistributionGPU",
    "TDistributionGPU",
    "UniformDistributionGPU",
    "ExponDistributionGPU",
    "CauchyDistributionGPU",
    "LaplaceDistributionGPU",
    "LogisticDistributionGPU",
    "Chi2DistributionGPU",
    "GammaDistributionGPU",
    "BetaDistributionGPU",
    "FDistributionGPU",
    "WeibullMinDistributionGPU",
    "LognormDistributionGPU",
    "PoissonDistributionGPU",
    "BinomDistributionGPU",
    "ScipyFallbackDistribution",
    "get_distribution_gpu",
    "list_available_distributions_gpu",
]
|