statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2610 @@
1
+ """
2
+ Unified probability-distribution backend.
3
+
4
+ Supports ``numpy``, ``cupy``, and ``torch`` backends through a single
5
+ ``SpecialFunctions`` protocol, eliminating code duplication across
6
+ ``_distributions_gpu.py`` and ``_distributions_torch.py``.
7
+
8
+ Usage::
9
+
10
+ from statgpu.inference._distributions_backend import get_distribution, norm, t
11
+
12
+ # Explicit backend
13
+ norm_dist = get_distribution("norm", backend="numpy")
14
+ norm_dist.cdf([0.0, 1.0, 2.0])
15
+
16
+ # Module-level proxy with auto backend detection
17
+ norm.cdf([0.0, 1.0, 2.0])
18
+ t.cdf(1.5, df=10, backend="cupy")
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import math
24
+ from abc import abstractmethod
25
+ from functools import lru_cache
26
+ from typing import Any, Protocol, runtime_checkable
27
+
28
+ import numpy as np
29
+
30
+ from statgpu.backends import _get_torch_device_str as _get_torch_device
31
+
32
+
33
+ # =============================================================================
34
+ # SpecialFunctions protocol — abstracts away library-specific special functions
35
+ # =============================================================================
36
+
37
+ @runtime_checkable
38
+ class SpecialFunctions(Protocol):
39
+ """Protocol for special-function providers.
40
+
41
+ Implementations: ``CuPySpecialFunctions``, ``TorchSpecialFunctions``,
42
+ ``ScipySpecialFunctions``.
43
+ """
44
+
45
+ @abstractmethod
46
+ def betainc(self, a, b, x):
47
+ """Regularized incomplete beta I_x(a, b)."""
48
+
49
+ @abstractmethod
50
+ def betaincinv(self, a, b, y):
51
+ """Inverse regularized incomplete beta."""
52
+
53
+ @abstractmethod
54
+ def gammainc(self, a, x):
55
+ """Regularized lower incomplete gamma P(a, x)."""
56
+
57
+ @abstractmethod
58
+ def gammaincc(self, a, x):
59
+ """Regularized upper incomplete gamma Q(a, x)."""
60
+
61
+ @abstractmethod
62
+ def gammaincinv(self, a, q):
63
+ """Inverse regularized lower incomplete gamma."""
64
+
65
+ @abstractmethod
66
+ def gammaln(self, x):
67
+ """Log-gamma."""
68
+
69
+ @abstractmethod
70
+ def erf(self, x):
71
+ """Error function."""
72
+
73
+ @abstractmethod
74
+ def erfc(self, x):
75
+ """Complementary error function."""
76
+
77
+ @abstractmethod
78
+ def erfcinv(self, y):
79
+ """Inverse complementary error function."""
80
+
81
+
82
+ # =============================================================================
83
+ # CuPy backend
84
+ # =============================================================================
85
+
86
+ class CuPySpecialFunctions:
87
+ """Special functions via cupyx.scipy.special with LUT acceleration.
88
+
89
+ Inverse special functions (betaincinv, gammaincinv) use GPU-resident LUT
90
+ + 1-step Newton refinement for ~10-100x speedup over cupyx iterative solver.
91
+ """
92
+
93
+ def __init__(self, *, use_lut: bool = True):
94
+ import cupy as cp
95
+ import cupyx.scipy.special as csp
96
+ self._cp = cp
97
+ self._csp = csp
98
+ self.use_lut = use_lut
99
+ # LUT caches for inverse special functions (instance-level)
100
+ self._betaincinv_lut = {}
101
+ self._gammaincinv_lut = {}
102
+
103
+ def betainc(self, a, b, x):
104
+ return self._csp.betainc(a, b, self._cp.asarray(x, dtype=self._cp.float64))
105
+
106
+ def betaincinv(self, a, b, y):
107
+ cp = self._cp
108
+ yt = cp.asarray(y, dtype=cp.float64)
109
+ try:
110
+ af, bf = float(a), float(b)
111
+ except (TypeError, ValueError):
112
+ return self._csp.betaincinv(a, b, yt)
113
+ if not self.use_lut:
114
+ return self._csp.betaincinv(a, b, yt)
115
+ if af < 0.3 or bf < 0.3 or af > 50 or bf > 50 or abs(af - bf) > 30:
116
+ return self._csp.betaincinv(a, b, yt)
117
+ key = (af, bf)
118
+ if key not in self._betaincinv_lut:
119
+ x_grid, y_grid = self._build_betaincinv_lut(af, bf, 20000)
120
+ self._betaincinv_lut[key] = (cp.asarray(x_grid), cp.asarray(y_grid))
121
+ yg, xg = self._betaincinv_lut[key]
122
+ idx = cp.searchsorted(yg, cp.clip(yt, 1e-15, 1.0 - 1e-15)).clip(1, len(yg) - 1)
123
+ y0, y1 = yg[idx - 1], yg[idx]
124
+ x0, x1 = xg[idx - 1], xg[idx]
125
+ w = (yt - y0) / (y1 - y0 + 1e-300)
126
+ x = cp.clip(x0 + w * (x1 - x0), 1e-10, 1.0 - 1e-10)
127
+ # 1-step Newton refine using cupyx betainc
128
+ import math as _math
129
+ log_beta = _math.lgamma(af) + _math.lgamma(bf) - _math.lgamma(af + bf)
130
+ p = self._csp.betainc(af, bf, x)
131
+ diff = p - yt
132
+ log_deriv = (af - 1.0) * cp.log(cp.clip(x, 1e-300, None)) + \
133
+ (bf - 1.0) * cp.log(cp.clip(1.0 - x, 1e-300, None)) - log_beta
134
+ deriv = cp.exp(log_deriv)
135
+ x1 = x - diff / cp.clip(deriv, 1e-300, 1e300)
136
+ return cp.clip(x1, 1e-15, 1.0 - 1e-15)
137
+
138
+ @staticmethod
139
+ def _build_betaincinv_lut(a, b, n_grid):
140
+ """Build LUT via scipy on CPU, returns (x_grid, y_grid) as numpy arrays.
141
+
142
+ Uses log spacing near both boundaries for better precision when
143
+ a or b is small (e.g. b=0.5 for t/f distributions).
144
+ """
145
+ import scipy.special as _scsp
146
+ eps = 1e-15
147
+ n_edge = int(n_grid * 0.4)
148
+ n_mid = n_grid - 2 * n_edge
149
+ x_lo = np.logspace(np.log10(eps), np.log10(0.01), n_edge)
150
+ x_mid = np.linspace(0.01, 0.99, n_mid + 2)[1:-1]
151
+ x_hi = 1.0 - np.logspace(np.log10(eps), np.log10(0.01), n_edge)[::-1]
152
+ x_grid = np.concatenate([x_lo, x_mid, x_hi])
153
+ if len(x_grid) > n_grid:
154
+ x_grid = x_grid[:n_grid]
155
+ y_grid = _scsp.betainc(a, b, x_grid)
156
+ return x_grid, y_grid
157
+
158
+ def gammainc(self, a, x):
159
+ return self._csp.gammainc(
160
+ self._cp.asarray(a, dtype=self._cp.float64),
161
+ self._cp.asarray(x, dtype=self._cp.float64),
162
+ )
163
+
164
+ def gammaincc(self, a, x):
165
+ return self._csp.gammaincc(
166
+ self._cp.asarray(a, dtype=self._cp.float64),
167
+ self._cp.asarray(x, dtype=self._cp.float64),
168
+ )
169
+
170
+ def gammaincinv(self, a, q):
171
+ cp = self._cp
172
+ qt = cp.asarray(q, dtype=cp.float64)
173
+ try:
174
+ af = float(a)
175
+ except (TypeError, ValueError):
176
+ return self._csp.gammaincinv(cp.asarray(a, dtype=cp.float64), qt)
177
+ if not self.use_lut:
178
+ return self._csp.gammaincinv(cp.asarray(a, dtype=cp.float64), qt)
179
+ if af < 1.0:
180
+ return self._csp.gammaincinv(cp.asarray(a, dtype=cp.float64), qt)
181
+ key = (af,)
182
+ if key not in self._gammaincinv_lut:
183
+ x_grid, y_grid = self._build_gammaincinv_lut(af, 20000)
184
+ self._gammaincinv_lut[key] = (cp.asarray(x_grid), cp.asarray(y_grid))
185
+ yg, xg = self._gammaincinv_lut[key]
186
+ idx = cp.searchsorted(yg, cp.clip(qt, 1e-15, 1.0 - 1e-15)).clip(1, len(yg) - 1)
187
+ y0, y1 = yg[idx - 1], yg[idx]
188
+ x0, x1 = xg[idx - 1], xg[idx]
189
+ w = (qt - y0) / (y1 - y0 + 1e-300)
190
+ x = cp.clip(x0 + w * (x1 - x0), 1e-15, 1e6)
191
+ # 1-step Newton refine using cupyx gammainc
192
+ import math as _math
193
+ log_ga = _math.lgamma(af)
194
+ p = self._csp.gammainc(af, x)
195
+ diff = p - qt
196
+ log_deriv = (af - 1.0) * cp.log(cp.clip(x, 1e-300, None)) - x - log_ga
197
+ deriv = cp.exp(log_deriv)
198
+ x1 = x - diff / cp.clip(deriv, 1e-300, 1e300)
199
+ return cp.clip(x1, 1e-15, 1e6)
200
+
201
+ @staticmethod
202
+ def _build_gammaincinv_lut(a, n_grid):
203
+ """Build LUT via scipy on CPU, returns (x_grid, y_grid) as numpy arrays."""
204
+ import math
205
+ import scipy.special as _scsp
206
+ x_max = a + 20 * math.sqrt(max(a, 0.1)) + 10
207
+ x_max = min(x_max, 1e6)
208
+ n_log = n_grid // 3
209
+ n_lin = n_grid - n_log
210
+ x_lo = np.logspace(-15, math.log10(max(x_max, 1e-10)), n_log, endpoint=False)
211
+ x_hi = np.linspace(x_lo[-1] if len(x_lo) > 0 else 0, x_max, n_lin + 1)[1:]
212
+ x_grid = np.concatenate([x_lo, x_hi])
213
+ if len(x_grid) < n_grid:
214
+ extra = np.linspace(x_grid[-1], x_max, n_grid - len(x_grid) + 2)[1:]
215
+ x_grid = np.concatenate([x_grid, extra])
216
+ x_grid = x_grid[:n_grid]
217
+ y_grid = _scsp.gammainc(a, x_grid)
218
+ y_grid[0] = 0.0
219
+ y_grid[-1] = 1.0
220
+ return x_grid, y_grid
221
+
222
+ def gammaln(self, x):
223
+ return self._csp.gammaln(self._cp.asarray(x, dtype=self._cp.float64))
224
+
225
+ def erf(self, x):
226
+ return self._csp.erf(self._cp.asarray(x, dtype=self._cp.float64))
227
+
228
+ def erfc(self, x):
229
+ return self._csp.erfc(self._cp.asarray(x, dtype=self._cp.float64))
230
+
231
+ def erfcinv(self, y):
232
+ return self._csp.erfcinv(self._cp.asarray(y, dtype=self._cp.float64))
233
+
234
+ def sqrt(self, x):
235
+ return self._cp.sqrt(self._cp.asarray(x, dtype=self._cp.float64))
236
+
237
+ @property
238
+ def pi(self):
239
+ return self._cp.pi
240
+
241
+ def clip(self, x, lo, hi):
242
+ return self._cp.clip(x, lo, hi)
243
+
244
+ def where(self, cond, x, y):
245
+ return self._cp.where(cond, x, y)
246
+
247
+ def as_float64(self, x):
248
+ return self._cp.asarray(x, dtype=self._cp.float64)
249
+
250
+
251
+ # =============================================================================
252
+ # Torch backend
253
+ # =============================================================================
254
+
255
+ # Module-level cache for torch betaincinv inverse LUTs (scalar a, b)
256
+ # Key: (a, b, device) -> (y_grid, x_grid) tensors on device
257
+ _torch_betaincinv_lut_cache: dict = {}
258
+
259
+
260
+ # Module-level cache for torch betainc forward LUTs (scalar a, b)
261
+ # Key: (a, b, device) -> (x_grid, y_grid) tensors on device
262
+ _torch_betainc_lut_cache: dict = {}
263
+
264
+
265
+ def _get_torch_betaincinv_lut(a, b, device, n_points=20000):
266
+ """Build a GPU-resident inverse LUT for betaincinv(a, b, y).
267
+
268
+ Precomputes x = betaincinv(a, b, y) for 20K y values via scipy on CPU
269
+ (one-time cost, <200ms) then uses searchsorted for O(log n) lookup.
270
+ """
271
+ from scipy import special as _scsp
272
+ import torch
273
+
274
+ cache_key = (a, b, device)
275
+ if cache_key in _torch_betaincinv_lut_cache:
276
+ return _torch_betaincinv_lut_cache[cache_key]
277
+
278
+ y_vals = np.linspace(1e-15, 1.0 - 1e-15, n_points)
279
+ x_vals = _scsp.betaincinv(a, b, y_vals)
280
+ y_grid = torch.as_tensor(y_vals, dtype=torch.float64, device=device)
281
+ x_grid = torch.as_tensor(x_vals, dtype=torch.float64, device=device)
282
+ _torch_betaincinv_lut_cache[cache_key] = (y_grid, x_grid)
283
+ return y_grid, x_grid
284
+
285
+
286
+ def _get_torch_betainc_lut(a, b, device, n_points=40000):
287
+ """Build a GPU-resident forward LUT for betainc(a, b, x).
288
+
289
+ Precomputes y = betainc(a, b, x) for 40K x values via scipy on CPU
290
+ (one-time cost, <50ms) then uses searchsorted for O(log n) lookup.
291
+ Uses log spacing near boundaries for better precision when a or b is small.
292
+ """
293
+ from scipy import special as _scsp
294
+ import torch
295
+
296
+ cache_key = (a, b, device)
297
+ if cache_key in _torch_betainc_lut_cache:
298
+ return _torch_betainc_lut_cache[cache_key]
299
+
300
+ # Log spacing near boundaries for b < 1 singularity
301
+ eps = 1e-15
302
+ n_edge = int(n_points * 0.4)
303
+ n_mid = n_points - 2 * n_edge
304
+ x_lo = np.logspace(np.log10(eps), np.log10(0.01), n_edge)
305
+ x_mid = np.linspace(0.01, 0.99, n_mid + 2)[1:-1]
306
+ x_hi = 1.0 - np.logspace(np.log10(eps), np.log10(0.01), n_edge)[::-1]
307
+ x_vals = np.concatenate([x_lo, x_mid, x_hi])[:n_points]
308
+ y_vals = _scsp.betainc(a, b, x_vals)
309
+ x_grid = torch.as_tensor(x_vals, dtype=torch.float64, device=device)
310
+ y_grid = torch.as_tensor(y_vals, dtype=torch.float64, device=device)
311
+ _torch_betainc_lut_cache[cache_key] = (x_grid, y_grid)
312
+ return x_grid, y_grid
313
+
314
+
315
+
316
+ class TorchSpecialFunctions:
317
+ """Special functions via torch.special with fallbacks for missing functions."""
318
+
319
+ def __init__(self, device: str | None = None, *, use_lut: bool = True):
320
+ import torch
321
+ self._torch = torch
322
+ self._device = device or _get_torch_device()
323
+ self.use_lut = use_lut
324
+
325
+ def _as_tensor(self, x):
326
+ return self._torch.as_tensor(x, dtype=self._torch.float64, device=self._device)
327
+
328
+ # ── betainc fallback ───────────────────────────────────────────
329
+ def betainc(self, a, b, x):
330
+ t = self._torch
331
+ # Check if torch has native betainc (>= 1.8)
332
+ if hasattr(t.special, "betainc"):
333
+ return t.special.betainc(
334
+ self._as_tensor(a), self._as_tensor(b), self._as_tensor(x),
335
+ )
336
+ # LUT-based betainc for scalar a, b (major speedup for binom)
337
+ try:
338
+ af, bf = float(a), float(b)
339
+ except (TypeError, ValueError):
340
+ pass # fall through to element-wise loop below
341
+ else:
342
+ if self.use_lut:
343
+ try:
344
+ xg, yg = _get_torch_betainc_lut(af, bf, self._device)
345
+ xt = self._as_tensor(x)
346
+ xt_clamp = t.clamp(xt, 0.0, 1.0)
347
+ idx = t.searchsorted(xg, xt_clamp).clip(1, len(xg) - 1)
348
+ x0, x1 = xg[idx - 1], xg[idx]
349
+ y0, y1 = yg[idx - 1], yg[idx]
350
+ w = (xt_clamp - x0) / (x1 - x0 + 1e-300)
351
+ return (y0 + w * (y1 - y0)).clamp(0.0, 1.0).view_as(xt)
352
+ except Exception:
353
+ pass # fall through to integral fallback
354
+ return self._betainc_integral(af, bf, self._as_tensor(x))
355
+ # Non-scalar a or b — grouped LUT lookup (avoids element-wise Chebyshev integral)
356
+ try:
357
+ return self._betainc_batch(a, b, x)
358
+ except Exception:
359
+ # Full fallback: compute on CPU via scipy
360
+ try:
361
+ import scipy.special as _scsp
362
+ a_np = np.asarray(self._as_tensor(a).cpu().numpy())
363
+ b_np = np.asarray(self._as_tensor(b).cpu().numpy())
364
+ x_np = np.asarray(self._as_tensor(x).cpu().numpy())
365
+ result = _scsp.betainc(
366
+ np.clip(a_np, 1, None).astype(int),
367
+ np.clip(b_np, 1, None).astype(int),
368
+ np.clip(x_np, 0.0, 1.0),
369
+ )
370
+ return self._as_tensor(result)
371
+ except Exception:
372
+ return self._betainc_integral(1, 1, self._as_tensor(x))
373
+
374
+ def _betainc_integral(self, a, b, x):
375
+ """Regularized incomplete beta via trapezoidal rule on Chebyshev-mapped grid.
376
+
377
+ Uses Chebyshev-node mapping to cluster grid points near s=0 and s=1.
378
+ """
379
+ import math as _math
380
+ t = self._torch
381
+ device = x.device
382
+ x = t.clamp(x, 0.0, 1.0)
383
+ af, bf = float(a), float(b)
384
+ if af < 1.0 or bf < 1.0:
385
+ n_grid = 64000
386
+ elif af < 5.0 or bf < 5.0:
387
+ n_grid = 16000
388
+ else:
389
+ n_grid = 8000
390
+ theta = t.linspace(0, _math.pi, n_grid, device=device, dtype=t.float64)
391
+ s = 0.5 * (1.0 + t.cos(theta)) # descending [≈1, 0]
392
+ s = s.flip(0) # ascending [0, ≈1]
393
+ eps = 1e-14
394
+ log_val = (a - 1) * t.log(s + 1e-300) + (b - 1) * t.log1p(-s + 1e-300)
395
+ log_val = t.where(t.isfinite(log_val), log_val, t.tensor(-700.0, dtype=t.float64, device=device))
396
+ f = t.exp(log_val)
397
+ beta_ab = _math.exp(_math.lgamma(af) + _math.lgamma(bf) - _math.lgamma(af + bf))
398
+ ds = s[1:] - s[:-1]
399
+ cum = t.zeros(n_grid, device=device, dtype=t.float64)
400
+ cum[1:] = t.cumsum((f[:-1] + f[1:]) * 0.5 * ds, dim=0)
401
+ x_flat = x.flatten()
402
+ idx = t.searchsorted(s, x_flat, right=True).clamp(1, n_grid - 1)
403
+ frac = (x_flat - s[idx - 1]) / (s[idx] - s[idx - 1] + 1e-300)
404
+ frac = frac.clamp(0.0, 1.0)
405
+ result = cum[idx - 1] + frac * (cum[idx] - cum[idx - 1])
406
+ result = result / beta_ab
407
+ result = t.clamp(result, 0.0, 1.0)
408
+ result = t.where(x_flat <= eps, 0.0, result)
409
+ result = t.where(x_flat >= 1 - eps, 1.0, result)
410
+ return result.view_as(x)
411
+
412
+ def _betainc_batch(self, a, b, x):
413
+ """Batch betainc for non-scalar a, b via fused 2D-LUT interpolation.
414
+
415
+ All LUTs share the same x-grid (fixed log-spaced scheme), so we:
416
+ 1. Build a 2D y-grid of shape (n_pairs, n_grid) for all unique (a,b) pairs
417
+ 2. Call searchsorted ONCE to find the bracket index for all elements
418
+ 3. Interpolate all pairs simultaneously via batched gather
419
+ 4. Scatter results back to output positions
420
+
421
+ This avoids 100+ separate searchsorted calls, reducing overhead by ~100x.
422
+ """
423
+ x_flat = self._as_tensor(x).flatten()
424
+ a_flat = self._as_tensor(a).flatten()
425
+ b_flat = self._as_tensor(b).flatten()
426
+ t = self._torch
427
+
428
+ # Clamp to >= 1 for key encoding (edge cases get overwritten by caller)
429
+ ai = t.clamp(t.round(a_flat).long(), 1, 100000)
430
+ bi = t.clamp(t.round(b_flat).long(), 1, 100000)
431
+ # Encode as single key for unique computation
432
+ keys = ai * 100000 + bi
433
+ unique_keys, inverse_idx = t.unique(keys, return_inverse=True)
434
+
435
+ n_pairs = unique_keys.numel()
436
+ n_elem = x_flat.numel()
437
+
438
+ # Build 2D grid: (n_pairs, n_grid)
439
+ # All LUTs share the same x-grid, so we only need one
440
+ y_grid = t.zeros((n_pairs, 40000), dtype=t.float64, device=self._device)
441
+ xg = None
442
+ n_actual = 0
443
+ failed_pairs = []
444
+ for pi in range(n_pairs):
445
+ k_val = unique_keys[pi].item()
446
+ a_val = k_val // 100000
447
+ b_val = k_val - a_val * 100000
448
+ try:
449
+ xg_i, yg_i = _get_torch_betainc_lut(a_val, b_val, self._device)
450
+ if xg is None:
451
+ xg = xg_i # all LUTs share the same x-grid
452
+ n_actual = len(xg_i)
453
+ y_grid[pi, :len(yg_i)] = yg_i
454
+ except Exception:
455
+ failed_pairs.append((pi, float(a_val), float(b_val)))
456
+
457
+ if xg is None:
458
+ # All LUTs failed, fall back
459
+ return self._betainc_integral(1, 1, x_flat)
460
+
461
+ xg = xg[:n_actual]
462
+ y_grid = y_grid[:, :n_actual]
463
+
464
+ # Single searchsorted for all elements
465
+ x_clamp = t.clamp(x_flat, 0.0, 1.0)
466
+ sidx = t.searchsorted(xg, x_clamp).clip(1, n_actual - 1)
467
+
468
+ # Interpolation weights (same for all pairs)
469
+ x0g, x1g = xg[sidx - 1], xg[sidx]
470
+ w = (x_clamp - x0g) / (x1g - x0g + 1e-300)
471
+
472
+ # Gather y0/y1 for all pairs simultaneously: (n_pairs, n_elem)
473
+ y0_all = y_grid[:, sidx - 1] # (n_pairs, n_elem)
474
+ y1_all = y_grid[:, sidx] # (n_pairs, n_elem)
475
+ y_all = y0_all + w.unsqueeze(0) * (y1_all - y0_all) # (n_pairs, n_elem)
476
+ y_all = y_all.clamp(0.0, 1.0)
477
+
478
+ # Scatter: select the right pair index for each element
479
+ # inverse_idx: (n_elem,) → indices into pair dimension
480
+ # y_all: (n_pairs, n_elem) → gather along dim=0
481
+ result = y_all[inverse_idx, t.arange(n_elem, device=self._device)]
482
+ if failed_pairs:
483
+ for pi, a_val, b_val in failed_pairs:
484
+ mask = inverse_idx == pi
485
+ if t.any(mask):
486
+ result[mask] = self._betainc_integral(a_val, b_val, x_clamp[mask])
487
+
488
+ return result.view(self._as_tensor(a).shape)
489
+
490
+ def betaincinv(self, a, b, y):
491
+ t = self._torch
492
+ af, bf = float(a), float(b)
493
+ yt = self._as_tensor(y)
494
+ if hasattr(t.special, "betaincinv"):
495
+ return t.special.betaincinv(
496
+ self._as_tensor(a), self._as_tensor(b), yt,
497
+ )
498
+ # For scalar a, b: LUT lookup + 1-step Newton refine
499
+ if not self.use_lut:
500
+ return self._betaincinv_newton(af, bf, yt)
501
+ try:
502
+ y_grid, x_grid = _get_torch_betaincinv_lut(af, bf, yt.device)
503
+ # Searchsorted to find bracket index
504
+ idx = t.searchsorted(y_grid, t.clamp(yt, 0.0, 1.0)).clamp(1, len(y_grid) - 1)
505
+ # Interpolate between two nearest LUT points
506
+ y0, y1 = y_grid[idx - 1], y_grid[idx]
507
+ x0, x1 = x_grid[idx - 1], x_grid[idx]
508
+ w = (yt - y0) / (y1 - y0 + 1e-300)
509
+ x = x0 + w * (x1 - x0)
510
+ x = t.clamp(x, 1e-10, 1.0 - 1e-10)
511
+ # 1-step Newton refine
512
+ import math as _math
513
+ beta_ab = math.exp(math.lgamma(af) + math.lgamma(bf) - math.lgamma(af + bf))
514
+ val = self._betainc_integral(af, bf, x)
515
+ deriv = t.pow(t.clamp(x, 1e-300, 1 - 1e-300), af - 1) * \
516
+ t.pow(t.clamp(1 - x, 1e-300, 1 - 1e-300), bf - 1) / beta_ab
517
+ deriv = t.clamp(deriv, 1e-300, 1e300)
518
+ step = (val - yt) / deriv
519
+ x = x - step
520
+ x = t.clamp(x, 1e-10, 1.0 - 1e-10)
521
+ return x
522
+ except Exception:
523
+ return self._betaincinv_newton(af, bf, yt)
524
+
525
+ def _betaincinv_newton(self, a, b, y):
526
+ """Inverse regularized incomplete beta via damped Newton-Raphson."""
527
+ t = self._torch
528
+ device = y.device
529
+ y = t.clamp(y, 1e-15, 1 - 1e-15)
530
+ import math as _math
531
+ # Logit-normal approximation for initial guess
532
+ import scipy.special as _scsp
533
+ mu = _scsp.digamma(a) - _scsp.digamma(b)
534
+ sigma2 = 1.0 / a + 1.0 / b
535
+ sigma = math.sqrt(sigma2)
536
+ z = -_math.sqrt(2.0) * self.erfcinv(2.0 * y)
537
+ z = self._as_tensor(z) if not isinstance(z, t.Tensor) else z
538
+ logit_q = mu + sigma * z
539
+ x = 1.0 / (1.0 + t.exp(-logit_q))
540
+ x = t.clamp(x, 1e-10, 1.0 - 1e-10)
541
+ # Damped Newton refinement
542
+ beta_ab = math.exp(math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b))
543
+ for _ in range(50):
544
+ val = self._betainc_integral(a, b, x)
545
+ diff = val - y
546
+ if t.max(t.abs(diff)) < 1e-13:
547
+ break
548
+ deriv = t.pow(t.clamp(x, 1e-300, 1 - 1e-300), a - 1) * \
549
+ t.pow(t.clamp(1 - x, 1e-300, 1 - 1e-300), b - 1) / beta_ab
550
+ deriv = t.clamp(deriv, 1e-300, 1e300)
551
+ step = diff / deriv
552
+
553
+ # Damped: backtracking to keep x in valid range
554
+ for _ in range(20):
555
+ x_new = x - step
556
+ if t.min(x_new) < 1e-15 or t.max(x_new) > 1.0 - 1e-15:
557
+ step = step * 0.5
558
+ else:
559
+ break
560
+
561
+ x = x - step
562
+ x = t.clamp(x, 1e-10, 1.0 - 1e-10)
563
+ return x
564
+
565
+ def gammainc(self, a, x):
566
+ return self._torch.special.gammainc(self._as_tensor(a), self._as_tensor(x))
567
+
568
+ def gammaincc(self, a, x):
569
+ return self._torch.special.gammaincc(self._as_tensor(a), self._as_tensor(x))
570
+
571
+ def gammaincinv(self, a, q):
572
+ t = self._torch
573
+ af = float(a)
574
+ qt = self._as_tensor(q)
575
+ if hasattr(t.special, "gammaincinv"):
576
+ return t.special.gammaincinv(self._as_tensor(a), qt)
577
+ return self._gammaincinv_newton(af, qt)
578
+
579
+ def _gammaincinv_newton(self, a, q):
580
+ """Inverse regularized lower incomplete gamma via damped Newton-Raphson."""
581
+ t = self._torch
582
+ device = q.device
583
+ q = t.clamp(q, 1e-15, 1 - 1e-15)
584
+ import math
585
+ at = t.tensor(a, dtype=t.float64, device=device)
586
+
587
+ # Wilson-Hilferty initial guess (much better than erfinv-based)
588
+ # For gamma(a,1): P(a,x) ≈ Φ((x/a)^(1/3) - (1 - 1/(9a))) / sqrt(1/(9a))
589
+ z = math.sqrt(2.0) * t.erfinv(2.0 * q - 1.0)
590
+ c = 1.0 - 1.0 / (9.0 * a)
591
+ s = 1.0 / math.sqrt(9.0 * a)
592
+ u = z * s + c
593
+ x = a * t.pow(u, 3.0)
594
+ x = t.clamp(x, 1e-10, 1e6)
595
+
596
+ lg_a = math.lgamma(a)
597
+ for _ in range(50):
598
+ val = t.special.gammainc(at, x)
599
+ diff = val - q
600
+ if t.max(t.abs(diff)) < 1e-13:
601
+ break
602
+ # derivative: d/dx P(a,x) = x^(a-1) * e^(-x) / Gamma(a)
603
+ log_deriv = (a - 1.0) * t.log(t.clamp(x, 1e-300, None)) - x - lg_a
604
+ deriv = t.exp(log_deriv)
605
+ deriv = t.clamp(deriv, 1e-300, 1e300)
606
+ step = diff / deriv
607
+
608
+ # Damped: backtracking line search to prevent oscillation
609
+ # Accept full step if it stays in bounds; otherwise halve
610
+ damped = False
611
+ for _ in range(20):
612
+ x_new = x - step
613
+ if t.min(x_new) < 1e-15 or t.max(x_new) > 2e6:
614
+ step = step * 0.5
615
+ damped = True
616
+ else:
617
+ break
618
+
619
+ x = x - step
620
+ x = t.clamp(x, 1e-10, 1e6)
621
+ return x
622
+
623
def gammaln(self, x):
    """Log-gamma of ``x`` via ``torch.lgamma`` after ``self._as_tensor``."""
    arg = self._as_tensor(x)
    return self._torch.lgamma(arg)
625
+
626
def erf(self, x):
    """Error function of ``x`` via ``torch.erf`` after ``self._as_tensor``."""
    arg = self._as_tensor(x)
    return self._torch.erf(arg)
628
+
629
def erfc(self, x):
    """Complementary error function via ``torch.erfc`` after ``self._as_tensor``."""
    arg = self._as_tensor(x)
    return self._torch.erfc(arg)
631
+
632
def erfcinv(self, y):
    """Inverse complementary error function.

    Prefers ``torch.special.erfcinv`` when the installed torch exposes it;
    otherwise uses the identity ``erfcinv(y) = erfinv(1 - y)``.
    """
    t = self._torch
    yt = self._as_tensor(y)
    if not hasattr(t.special, "erfcinv"):
        # Fallback identity: erfcinv(y) = erfinv(1 - y)
        return t.erfinv(1.0 - yt)
    return t.special.erfcinv(yt)
639
+
640
def sqrt(self, x):
    """Element-wise square root via ``torch.sqrt`` after ``self._as_tensor``."""
    arg = self._as_tensor(x)
    return self._torch.sqrt(arg)
642
+
643
@property
def pi(self):
    """The constant pi as a float64 tensor placed on ``self._device``."""
    t = self._torch
    return t.tensor(math.pi, dtype=t.float64, device=self._device)
646
+
647
def clip(self, x, lo, hi):
    """Clamp ``x`` to ``[lo, hi]`` using ``torch.clamp``.

    ``x`` is passed through unchanged (no ``_as_tensor`` conversion).
    """
    clamp = self._torch.clamp
    return clamp(x, lo, hi)
649
+
650
def where(self, cond, x, y):
    """Element-wise select: ``x`` where ``cond`` is true, else ``y``.

    ``torch.where`` requires a boolean condition, so a non-bool tensor
    condition is cast to ``torch.bool`` first. Non-tensor conditions are
    forwarded unchanged.
    """
    t = self._torch
    needs_cast = isinstance(cond, t.Tensor) and cond.dtype != t.bool
    mask = cond.to(dtype=t.bool) if needs_cast else cond
    return t.where(mask, x, y)
656
+
657
def as_float64(self, x):
    """Convert ``x`` with ``self._as_tensor`` and return the result."""
    converted = self._as_tensor(x)
    return converted
659
+
660
+
661
+ # =============================================================================
662
+ # SciPy / NumPy backend
663
+ # =============================================================================
664
+
665
class ScipySpecialFunctions:
    """Special functions via scipy.special (pure NumPy / CPU).

    Inverse functions (gammaincinv, betaincinv) use cached LUT + interpolation
    for ~100ms evaluation on 1M points (vs ~3000ms for scipy's iterative solver).
    Accuracy: ~1e-5 for typical parameter ranges.
    For edge-case parameters (extreme a, b), falls back to scipy for full accuracy.

    On the LUT path the boundary probabilities are returned exactly
    (``q == 0`` and ``q == 1``) and inputs outside ``[0, 1]`` (or NaN) yield
    NaN, so both paths agree at the edges.
    """

    def __init__(self, *, use_lut: bool = True):
        """Create the backend; ``use_lut=False`` always defers to scipy."""
        import scipy.special as scsp
        self._scsp = scsp
        self.use_lut = use_lut
        # LUT cache for inverse functions (scalar a/b cases, well-behaved parameters)
        self._gammaincinv_lut = {}
        self._betaincinv_lut = {}

    @staticmethod
    @lru_cache(maxsize=256)
    def _make_gammaincinv_lut(a, n_grid):
        """Build LUT: x_grid -> y = gammainc(a, x_grid).

        The grid is a log-spaced segment starting at 1e-15 followed by a
        linearly spaced segment ending at ``x_max``; the first and last
        table values are pinned to 0 and 1.
        """
        import scipy.special as _scsp
        x_max = a + 20 * math.sqrt(max(a, 0.1)) + 10
        x_max = min(x_max, 1e6)
        n_log = n_grid // 3
        n_lin = n_grid - n_log
        x_lo = np.logspace(-15, math.log10(max(x_max, 1e-10)), n_log, endpoint=False)
        x_hi = np.linspace(x_lo[-1] if len(x_lo) > 0 else 0, x_max, n_lin + 1)[1:]
        x_grid = np.concatenate([x_lo, x_hi])
        if len(x_grid) < n_grid:
            extra = np.linspace(x_grid[-1], x_max, n_grid - len(x_grid) + 2)[1:]
            x_grid = np.concatenate([x_grid, extra])
        x_grid = x_grid[:n_grid]
        y_grid = _scsp.gammainc(a, x_grid)
        y_grid[0] = 0.0
        y_grid[-1] = 1.0
        return x_grid, y_grid

    @staticmethod
    @lru_cache(maxsize=256)
    def _make_betaincinv_lut(a, b, n_grid):
        """Build LUT: x_grid -> y = betainc(a, b, x_grid).

        Uses log spacing near both boundaries for better precision when
        a or b is small (e.g. b=0.5 for t/f distributions).
        """
        import scipy.special as _scsp
        eps = 1e-15
        # Log spacing: 40% near each boundary, 20% in the middle
        n_edge = int(n_grid * 0.4)
        n_mid = n_grid - 2 * n_edge
        x_lo = np.logspace(np.log10(eps), np.log10(0.01), n_edge)
        x_mid = np.linspace(0.01, 0.99, n_mid + 2)[1:-1]
        x_hi = 1.0 - np.logspace(np.log10(eps), np.log10(0.01), n_edge)[::-1]
        x_grid = np.concatenate([x_lo, x_mid, x_hi])
        if len(x_grid) > n_grid:
            x_grid = x_grid[:n_grid]
        y_grid = _scsp.betainc(a, b, x_grid)
        return x_grid, y_grid

    @staticmethod
    def _inverse_lut(q_or_y, x_grid, y_grid):
        """Use LUT for inverse: given q, find x such that f(x) = q."""
        idx = np.searchsorted(y_grid, q_or_y, side='left').clip(1, len(y_grid) - 1)
        # The 1e-300 guards against a zero denominator on flat table segments.
        frac = (q_or_y - y_grid[idx - 1]) / (y_grid[idx] - y_grid[idx - 1] + 1e-300)
        frac = np.clip(frac, 0.0, 1.0)
        return x_grid[idx - 1] + frac * (x_grid[idx] - x_grid[idx - 1])

    @staticmethod
    def _restore_boundaries(q, interior, at_zero, at_one):
        """Patch exact edge values into an interior LUT/Newton result.

        ``q == 0`` maps to ``at_zero``, ``q == 1`` maps to ``at_one`` and any
        ``q`` outside ``[0, 1]`` (including NaN) maps to NaN. 0-d results are
        unwrapped to NumPy scalars.
        """
        out = np.where(q == 0.0, at_zero, interior)
        out = np.where(q == 1.0, at_one, out)
        out = np.where((q >= 0.0) & (q <= 1.0), out, np.nan)
        return out[()]

    def betainc(self, a, b, x):
        """Regularized incomplete beta function I_x(a, b)."""
        return self._scsp.betainc(a, b, np.asarray(x, dtype=np.float64))

    def betaincinv(self, a, b, y):
        """Inverse of the regularized incomplete beta function in ``x``.

        Falls back to ``scipy.special.betaincinv`` when ``a``/``b`` are not
        scalar-convertible, when ``use_lut`` is false, or when the
        parameters are outside the range the LUT is used for.
        """
        arr = np.asarray(y, dtype=np.float64)
        try:
            af, bf = float(a), float(b)
        except (TypeError, ValueError):
            return self._scsp.betaincinv(a, b, arr)
        if not self.use_lut:
            return self._scsp.betaincinv(af, bf, arr)
        if af < 0.3 or bf < 0.3 or af > 50 or bf > 50 or abs(af - bf) > 30:
            return self._scsp.betaincinv(af, bf, arr)
        # LUT + 1-step Newton refinement
        key = (af, bf)
        if key not in self._betaincinv_lut:
            x_grid, y_grid = self._make_betaincinv_lut(af, bf, 20000)
            self._betaincinv_lut[key] = (x_grid, y_grid)
        x_grid, y_grid = self._betaincinv_lut[key]
        x0 = self._inverse_lut(arr, x_grid, y_grid)
        # 1 step of Newton; derivative is the beta density, formed in log space
        log_beta = math.lgamma(af) + math.lgamma(bf) - math.lgamma(af + bf)
        p = self._scsp.betainc(af, bf, x0)
        diff = p - arr
        log_deriv = (af - 1.0) * np.log(np.clip(x0, 1e-300, None)) + \
            (bf - 1.0) * np.log(np.clip(1.0 - x0, 1e-300, None)) - log_beta
        deriv = np.exp(log_deriv)
        x1 = x0 - diff / np.clip(deriv, 1e-300, 1e300)
        interior = np.clip(x1, 1e-15, 1.0 - 1e-15)
        # The clip above would otherwise turn y == 0 / y == 1 into
        # 1e-15 / 1 - 1e-15 instead of the exact endpoints.
        return self._restore_boundaries(arr, interior, 0.0, 1.0)

    def gammainc(self, a, x):
        """Regularized lower incomplete gamma function P(a, x)."""
        return self._scsp.gammainc(np.asarray(a, dtype=np.float64),
                                   np.asarray(x, dtype=np.float64))

    def gammaincc(self, a, x):
        """Regularized upper incomplete gamma function Q(a, x)."""
        return self._scsp.gammaincc(np.asarray(a, dtype=np.float64),
                                    np.asarray(x, dtype=np.float64))

    def gammaincinv(self, a, q):
        """Inverse of the regularized lower incomplete gamma function.

        Falls back to ``scipy.special.gammaincinv`` when ``a`` is not
        scalar-convertible, when ``use_lut`` is false, or when ``a < 1``.
        """
        arr = np.asarray(q, dtype=np.float64)
        try:
            af = float(a)
        except (TypeError, ValueError):
            return self._scsp.gammaincinv(a, arr)
        if not self.use_lut:
            return self._scsp.gammaincinv(af, arr)
        if af < 1.0:
            return self._scsp.gammaincinv(af, arr)
        # LUT + 1-step Newton refinement
        key = (af,)
        if key not in self._gammaincinv_lut:
            x_grid, y_grid = self._make_gammaincinv_lut(af, 20000)
            self._gammaincinv_lut[key] = (x_grid, y_grid)
        x_grid, y_grid = self._gammaincinv_lut[key]
        x0 = self._inverse_lut(arr, x_grid, y_grid)
        # 1 step of Newton: x = x0 - (P(a, x0) - q) / P'(a, x0)
        log_ga = math.lgamma(af)
        p = self._scsp.gammainc(af, x0)
        diff = p - arr
        log_deriv = (af - 1.0) * np.log(np.clip(x0, 1e-300, None)) - x0 - log_ga
        deriv = np.exp(log_deriv)
        x1 = x0 - diff / np.clip(deriv, 1e-300, 1e300)
        interior = np.clip(x1, 1e-15, 1e6)
        # q == 0 -> 0 and q == 1 -> +inf, rather than the clip bounds.
        return self._restore_boundaries(arr, interior, 0.0, np.inf)

    def gammaln(self, x):
        """Log-gamma of ``x``."""
        return self._scsp.gammaln(np.asarray(x, dtype=np.float64))

    def erf(self, x):
        """Error function of ``x``."""
        return self._scsp.erf(np.asarray(x, dtype=np.float64))

    def erfc(self, x):
        """Complementary error function of ``x``."""
        return self._scsp.erfc(np.asarray(x, dtype=np.float64))

    def erfcinv(self, y):
        """Inverse complementary error function of ``y``."""
        return self._scsp.erfcinv(np.asarray(y, dtype=np.float64))

    def sqrt(self, x):
        """Element-wise square root as float64."""
        return np.sqrt(np.asarray(x, dtype=np.float64))

    @property
    def pi(self):
        """The constant pi (``numpy.pi``)."""
        return np.pi

    def clip(self, x, lo, hi):
        """Clamp ``x`` to ``[lo, hi]``."""
        return np.clip(x, lo, hi)

    def where(self, cond, x, y):
        """Element-wise select: ``x`` where ``cond`` is truthy, else ``y``."""
        return np.where(cond, x, y)

    def as_float64(self, x):
        """Convert ``x`` to a float64 NumPy array."""
        return np.asarray(x, dtype=np.float64)
824
+
825
+
826
+ # =============================================================================
827
+ # Distribution base classes — parameterized by SpecialFunctions
828
+ # =============================================================================
829
+
830
# Initial lower/upper search bracket used by TDistributionBase._ppf_bisect
# (the bisection fallback for the standard Student-t quantile).
_T_PPF_BISECT_LOWER = -64.0
_T_PPF_BISECT_UPPER = 64.0
832
+
833
+
834
class NormDistributionBase:
    """scipy.stats.norm-like distribution, parameterized by SpecialFunctions.

    All array math is delegated to the injected backend ``sf``. Public
    methods return an all-NaN result when ``scale <= 0``.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def _cdf_standard(self, x):
        """Standard normal CDF: ``0.5 * (1 + erf(x / sqrt(2)))``."""
        sf = self._sf
        return 0.5 * (1.0 + sf.erf(x / sf.sqrt(2.0)))

    def _sf_standard(self, x):
        """Standard normal survival function: ``0.5 * erfc(x / sqrt(2))``."""
        return sf_safe_mul(self._sf.erfc(x / self._sf.sqrt(2.0)), 0.5, self._sf)

    def _ppf_standard(self, q):
        """Standard normal quantile: ``-sqrt(2) * erfcinv(2q)``."""
        return -self._sf.sqrt(2.0) * self._sf.erfcinv(2.0 * q)

    def _isf_standard(self, q):
        """Standard normal inverse survival function.

        By symmetry ``isf(q) = -ppf(q) = sqrt(2) * erfcinv(2q)``. Evaluating
        it directly avoids forming ``1 - q``, which rounds to 1 for very
        small ``q`` and would return ``inf``.
        """
        return self._sf.sqrt(2.0) * self._sf.erfcinv(2.0 * q)

    def _two_sided_pvalue_standard(self, stat_abs):
        """Two-sided p-value ``2 * sf(stat_abs)``, clipped to ``[0, 1]``."""
        sf = self._sf
        return sf.clip(2.0 * self._sf_standard(sf.as_float64(stat_abs)), 0.0, 1.0)

    def _two_sided_critical_value_standard(self, alpha):
        """Upper ``alpha / 2`` critical value; NaN unless ``0 < alpha < 1``."""
        sf = self._sf
        a = float(alpha)
        if not (0.0 < a < 1.0):
            return sf.as_float64(float("nan"))
        # Equivalent to ppf(1 - a/2) but keeps precision for tiny alpha.
        return self._isf_standard(a / 2.0)

    def cdf(self, x, *, loc=0.0, scale=1.0):
        """Cumulative distribution function at ``x``."""
        sf = self._sf
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(sf.as_float64(x) * 0 + 1, float("nan"), float("nan"))
        x_std = (sf.as_float64(x) - float(loc)) / scale_f
        return self._cdf_standard(x_std)

    def sf(self, x, *, loc=0.0, scale=1.0):
        """Survival function ``1 - cdf`` at ``x``."""
        sf = self._sf
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(sf.as_float64(x) * 0 + 1, float("nan"), float("nan"))
        x_std = (sf.as_float64(x) - float(loc)) / scale_f
        return self._sf_standard(x_std)

    def ppf(self, q, *, loc=0.0, scale=1.0):
        """Quantile function (inverse CDF) at probability ``q``."""
        sf = self._sf
        q_f = sf.as_float64(q)
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        return float(loc) + scale_f * self._ppf_standard(q_f)

    def isf(self, q, *, loc=0.0, scale=1.0):
        """Inverse survival function at upper-tail probability ``q``."""
        sf = self._sf
        q_f = sf.as_float64(q)
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        return float(loc) + scale_f * self._isf_standard(q_f)

    def pdf(self, x, *, loc=0.0, scale=1.0):
        """Probability density at ``x``."""
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        norm_const = sf.sqrt(2.0 * sf.pi)
        return sf.exp(-0.5 * sf.square(z)) / (scale_f * norm_const)

    def two_sided_pvalue(self, stat_abs):
        """Two-sided p-value for an absolute standard-normal statistic."""
        return self._two_sided_pvalue_standard(stat_abs)

    def two_sided_critical_value(self, alpha):
        """Two-sided critical value at significance level ``alpha``."""
        return self._two_sided_critical_value_standard(alpha)

    def rvs(self, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw random variates via ``_rvs_normal``.

        ``dtype`` is accepted but not forwarded.
        """
        return _rvs_normal(self._sf, size=size, loc=loc, scale=scale)
914
+
915
+
916
class TDistributionBase:
    """scipy.stats.t-like distribution, parameterized by SpecialFunctions.

    All array math is delegated to the injected backend ``sf``. Results are
    all-NaN when ``df <= 0`` or ``scale <= 0``. Upper-tail quantities
    (``sf``, ``isf``, two-sided critical value) are computed through the
    symmetry of the distribution so that very small tail probabilities keep
    their precision.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def _cdf_standard(self, x, df):
        """Standard Student-t CDF via the regularized incomplete beta."""
        sf = self._sf
        df_val = float(df)
        if df_val <= 0:
            return sf.where(x * 0 + 1, float("nan"), float("nan"))
        z = df_val / (df_val + sf.square(sf.abs(x)))
        ibeta = sf.betainc(df_val / 2.0, 0.5, z)
        lower_tail = 0.5 * ibeta
        return sf.where(x >= 0.0, 1.0 - lower_tail, lower_tail)

    def _sf_standard(self, x, df):
        """Standard survival function, evaluated as ``cdf(-x)``.

        The distribution is symmetric about 0, so this equals ``1 - cdf(x)``
        without the cancellation that would flush small tails to zero.
        """
        return self._cdf_standard(-x, df)

    def _two_sided_pvalue_standard(self, stat_abs, df):
        """Two-sided p-value ``P(|T| >= |stat|)``."""
        sf = self._sf
        df_val = float(df)
        if df_val <= 0:
            return sf.where(stat_abs * 0 + 1, float("nan"), float("nan"))
        z = df_val / (df_val + sf.square(sf.abs(stat_abs)))
        return sf.betainc(df_val / 2.0, 0.5, z)

    def _ppf_standard(self, q, df, *, max_bisect_steps=60):
        """Standard quantile function.

        Uses the inverse incomplete beta on the smaller tail
        ``min(q, 1 - q)``; if the backend raises, falls back to bisection
        on the CDF. ``q == 0`` / ``q == 1`` map to ``-inf`` / ``+inf`` and
        values outside ``[0, 1]`` map to NaN.
        """
        sf = self._sf
        df_val = float(df)
        if df_val <= 0:
            return sf.where(sf.as_float64(q) * 0 + 1, float("nan"), float("nan"))

        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        out = sf.where(q_f == 0.0, -float("inf"), out)
        out = sf.where(q_f == 1.0, float("inf"), out)

        valid = (q_f > 0.0) & (q_f < 1.0)
        if not bool(sf.any(valid)):
            return out

        try:
            tail = sf.minimum(q_f, 1.0 - q_f)
            y = 2.0 * tail
            y_inv = sf.betaincinv(df_val / 2.0, 0.5, y)
            x_pos = sf.sqrt(df_val * (1.0 - y_inv) / y_inv)
            quant = sf.where(q_f >= 0.5, x_pos, -x_pos)
            return sf.where(valid, quant, out)
        except Exception:
            # Deliberate best-effort: any backend failure on the closed-form
            # path is answered by the slower bisection solver.
            return self._ppf_bisect(q_f, df_val, valid, out, max_bisect_steps)

    def _ppf_bisect(self, q, df_val, valid, out, steps):
        """Bisection on the standard CDF within the module-level bracket."""
        sf = self._sf
        lo = sf.where(q * 0 + 1, _T_PPF_BISECT_LOWER, _T_PPF_BISECT_LOWER)
        hi = sf.where(q * 0 + 1, _T_PPF_BISECT_UPPER, _T_PPF_BISECT_UPPER)
        for _ in range(max(int(steps), 1)):
            mid = 0.5 * (lo + hi)
            cdf_mid = self._cdf_standard(mid, df_val)
            go_right = cdf_mid < q
            lo = sf.where(go_right, mid, lo)
            hi = sf.where(go_right, hi, mid)
        quant = 0.5 * (lo + hi)
        return sf.where(valid, quant, out)

    def cdf(self, x, df, *, loc=0.0, scale=1.0):
        """Cumulative distribution function at ``x``."""
        sf = self._sf
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(sf.as_float64(x) * 0 + 1, float("nan"), float("nan"))
        x_std = (sf.as_float64(x) - float(loc)) / scale_f
        return self._cdf_standard(x_std, df)

    def sf(self, x, df, *, loc=0.0, scale=1.0):
        """Survival function at ``x``."""
        sf = self._sf
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(sf.as_float64(x) * 0 + 1, float("nan"), float("nan"))
        x_std = (sf.as_float64(x) - float(loc)) / scale_f
        return self._sf_standard(x_std, df)

    def ppf(self, q, df, *, loc=0.0, scale=1.0, max_bisect_steps=60):
        """Quantile function (inverse CDF) at probability ``q``."""
        sf = self._sf
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(sf.as_float64(q) * 0 + 1, float("nan"), float("nan"))
        return float(loc) + scale_f * self._ppf_standard(q, df, max_bisect_steps=max_bisect_steps)

    def isf(self, q, df, *, loc=0.0, scale=1.0, max_bisect_steps=60):
        """Inverse survival function at upper-tail probability ``q``.

        Uses ``isf(q) = loc - scale * ppf_standard(q)`` (symmetry) instead of
        ``ppf(1 - q)``, which rounds tiny ``q`` away.
        """
        sf = self._sf
        scale_f = float(scale)
        if scale_f <= 0:
            return sf.where(sf.as_float64(q) * 0 + 1, float("nan"), float("nan"))
        return float(loc) - scale_f * self._ppf_standard(sf.as_float64(q), df, max_bisect_steps=max_bisect_steps)

    def pdf(self, x, df, *, loc=0.0, scale=1.0):
        """Probability density at ``x`` (computed in log space)."""
        sf = self._sf
        x_f = sf.as_float64(x)
        df_val = float(df)
        scale_f = float(scale)
        if df_val <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        half_nu = df_val / 2.0
        log_coef = (
            sf.gammaln((df_val + 1.0) / 2.0)
            - sf.gammaln(half_nu)
            - 0.5 * (sf.log(df_val) + sf.log(sf.pi))
        )
        log_pdf = (
            log_coef
            - ((df_val + 1.0) / 2.0) * sf.log1p(sf.square(z) / df_val)
            - sf.log(scale_f)
        )
        return sf.exp(log_pdf)

    def two_sided_pvalue(self, stat_abs, df):
        """Two-sided p-value for an absolute t statistic."""
        return self._two_sided_pvalue_standard(stat_abs, df)

    def two_sided_critical_value(self, alpha, df, *, max_bisect_steps=60):
        """Two-sided critical value; NaN unless ``0 < alpha < 1``."""
        sf = self._sf
        a = float(alpha)
        if not (0.0 < a < 1.0):
            return sf.as_float64(float("nan"))
        # -ppf(a/2) equals ppf(1 - a/2) by symmetry, without forming 1 - a/2.
        return -self._ppf_standard(a / 2.0, df, max_bisect_steps=max_bisect_steps)

    def rvs(self, df, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw random variates via ``_rvs_t``.

        ``dtype`` is accepted but not forwarded.
        """
        return _rvs_t(self._sf, df=df, size=size, loc=loc, scale=scale)
1044
+
1045
+
1046
class UniformDistributionBase:
    """Continuous uniform distribution on ``[loc, loc + scale]``.

    All array math is delegated to the injected backend. Every method
    returns an all-NaN result when ``scale <= 0``.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def _nan_like(self, arr):
        """Return an all-NaN result with the shape of ``arr``."""
        nan = float("nan")
        return self._sf.where(arr * 0 + 1, nan, nan)

    def cdf(self, x, *, loc=0.0, scale=1.0):
        """CDF: the standardized value clipped to ``[0, 1]``."""
        backend = self._sf
        width = float(scale)
        pts = backend.as_float64(x)
        if width <= 0.0:
            return self._nan_like(pts)
        return backend.clip((pts - float(loc)) / width, 0.0, 1.0)

    def sf(self, x, *, loc=0.0, scale=1.0):
        """Survival function ``1 - cdf``."""
        lower = self.cdf(x, loc=loc, scale=scale)
        return sf_safe_sub(1.0, lower, self._sf)

    def ppf(self, q, *, loc=0.0, scale=1.0):
        """Quantile ``loc + scale * q``; NaN for ``q`` outside ``[0, 1]``."""
        backend = self._sf
        width = float(scale)
        probs = backend.as_float64(q)
        nan_out = self._nan_like(probs)
        if width <= 0.0:
            return nan_out
        inside = (probs >= 0.0) & (probs <= 1.0)
        return backend.where(inside, float(loc) + width * probs, nan_out)

    def isf(self, q, *, loc=0.0, scale=1.0):
        """Inverse survival function, evaluated as ``ppf(1 - q)``."""
        upper = 1.0 - self._sf.as_float64(q)
        return self.ppf(upper, loc=loc, scale=scale)

    def pdf(self, x, *, loc=0.0, scale=1.0):
        """Density: ``1 / scale`` on the closed support, 0 elsewhere."""
        backend = self._sf
        width = float(scale)
        pts = backend.as_float64(x)
        if width <= 0.0:
            return self._nan_like(pts)
        z = (pts - float(loc)) / width
        return backend.where((z >= 0.0) & (z <= 1.0), 1.0 / width, 0.0)

    def rvs(self, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw random variates via ``_rvs_uniform`` (``dtype`` is not forwarded)."""
        return _rvs_uniform(self._sf, size=size, loc=loc, scale=scale)
1087
+
1088
+
1089
class ExponDistributionBase:
    """Exponential distribution with location ``loc`` and scale ``scale``.

    All array math is delegated to the injected backend. Every method
    returns an all-NaN result when ``scale <= 0``.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def cdf(self, x, *, loc=0.0, scale=1.0):
        """CDF ``1 - exp(-z)`` for ``z > 0``, else 0."""
        sf = self._sf
        x_f = sf.as_float64(x)
        scale_f = float(scale)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return sf.where(z <= 0.0, 0.0, 1.0 - sf.exp(-z))

    def sf(self, x, *, loc=0.0, scale=1.0):
        """Survival function ``exp(-z)`` for ``z > 0``, else 1."""
        sf = self._sf
        x_f = sf.as_float64(x)
        scale_f = float(scale)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return sf.where(z <= 0.0, 1.0, sf.exp(-z))

    def ppf(self, q, *, loc=0.0, scale=1.0):
        """Quantile ``loc - scale * log1p(-q)``; NaN outside ``[0, 1]``."""
        sf = self._sf
        scale_f = float(scale)
        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float(loc), out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) - scale_f * sf.log1p(-q_f), out)

    def isf(self, q, *, loc=0.0, scale=1.0):
        """Inverse survival function ``loc - scale * log(q)``.

        Evaluated directly rather than as ``ppf(1 - q)`` so that very small
        upper-tail probabilities are not rounded away. ``q == 1`` maps to
        ``loc``, ``q == 0`` to ``+inf`` and values outside ``[0, 1]`` to NaN.
        """
        sf = self._sf
        scale_f = float(scale)
        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if scale_f <= 0.0:
            return out
        out = sf.where(q_f == 1.0, float(loc), out)
        out = sf.where(q_f == 0.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) - scale_f * sf.log(q_f), out)

    def pdf(self, x, *, loc=0.0, scale=1.0):
        """Density ``exp(-z) / scale`` for ``z >= 0``, else 0."""
        sf = self._sf
        x_f = sf.as_float64(x)
        scale_f = float(scale)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return sf.where(z >= 0.0, sf.exp(-z) / scale_f, 0.0)

    def rvs(self, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw random variates via ``_rvs_expon`` (``dtype`` is not forwarded)."""
        return _rvs_expon(self._sf, size=size, loc=loc, scale=scale)
1137
+
1138
+
1139
class CauchyDistributionBase:
    """Cauchy distribution with location ``loc`` and scale ``scale``.

    All array math is delegated to the injected backend. Every method
    returns an all-NaN result when ``scale <= 0``.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def _nan_like(self, arr):
        """Return an all-NaN result with the shape of ``arr``."""
        nan = float("nan")
        return self._sf.where(arr * 0 + 1, nan, nan)

    def cdf(self, x, *, loc=0.0, scale=1.0):
        """CDF ``0.5 + atan(z) / pi``."""
        backend = self._sf
        gamma = float(scale)
        pts = backend.as_float64(x)
        if gamma <= 0.0:
            return self._nan_like(pts)
        z = (pts - float(loc)) / gamma
        return 0.5 + backend.atan(z) / backend.pi

    def sf(self, x, *, loc=0.0, scale=1.0):
        """Survival function ``1 - cdf``."""
        lower = self.cdf(x, loc=loc, scale=scale)
        return sf_safe_sub(1.0, lower, self._sf)

    def ppf(self, q, *, loc=0.0, scale=1.0):
        """Quantile ``loc + scale * tan(pi * (q - 0.5))``.

        ``q == 0`` / ``q == 1`` map to ``-inf`` / ``+inf``; values outside
        ``[0, 1]`` map to NaN.
        """
        backend = self._sf
        gamma = float(scale)
        probs = backend.as_float64(q)
        result = self._nan_like(probs)
        if gamma <= 0.0:
            return result
        result = backend.where(probs == 0.0, -float("inf"), result)
        result = backend.where(probs == 1.0, float("inf"), result)
        interior = (probs > 0.0) & (probs < 1.0)
        quantile = float(loc) + gamma * backend.tan(backend.pi * (probs - 0.5))
        return backend.where(interior, quantile, result)

    def isf(self, q, *, loc=0.0, scale=1.0):
        """Inverse survival function, evaluated as ``ppf(1 - q)``."""
        upper = 1.0 - self._sf.as_float64(q)
        return self.ppf(upper, loc=loc, scale=scale)

    def pdf(self, x, *, loc=0.0, scale=1.0):
        """Density ``1 / (pi * scale * (1 + z^2))``."""
        backend = self._sf
        gamma = float(scale)
        pts = backend.as_float64(x)
        if gamma <= 0.0:
            return self._nan_like(pts)
        z = (pts - float(loc)) / gamma
        return 1.0 / (backend.pi * gamma * (1.0 + backend.square(z)))

    def rvs(self, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw random variates via ``_rvs_cauchy`` (``dtype`` is not forwarded)."""
        return _rvs_cauchy(self._sf, size=size, loc=loc, scale=scale)
1181
+
1182
+
1183
class LaplaceDistributionBase:
    """Laplace (double exponential) distribution with ``loc`` and ``scale``.

    All array math is delegated to the injected backend. Every method
    returns an all-NaN result when ``scale <= 0``.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def cdf(self, x, *, loc=0.0, scale=1.0):
        """CDF: ``0.5 * exp(z)`` for ``z < 0``, else ``1 - 0.5 * exp(-z)``."""
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return sf.where(z < 0.0, 0.5 * sf.exp(z), 1.0 - 0.5 * sf.exp(-z))

    def sf(self, x, *, loc=0.0, scale=1.0):
        """Survival function: mirror image of :meth:`cdf`."""
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return sf.where(z < 0.0, 1.0 - 0.5 * sf.exp(z), 0.5 * sf.exp(-z))

    def ppf(self, q, *, loc=0.0, scale=1.0):
        """Quantile function; ``-inf`` / ``+inf`` at ``q == 0`` / ``q == 1``."""
        sf = self._sf
        scale_f = float(scale)
        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, -float("inf"), out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        lower = (q_f > 0.0) & (q_f < 0.5)
        upper = (q_f >= 0.5) & (q_f < 1.0)
        out = sf.where(lower, float(loc) + scale_f * sf.log(2.0 * q_f), out)
        out = sf.where(upper, float(loc) - scale_f * sf.log(2.0 * (1.0 - q_f)), out)
        return out

    def isf(self, q, *, loc=0.0, scale=1.0):
        """Inverse survival function in closed form.

        Mirrors :meth:`ppf` about ``loc`` instead of calling ``ppf(1 - q)``,
        so tiny upper-tail probabilities keep their precision:
        ``loc - scale * log(2q)`` for ``q < 0.5`` and
        ``loc + scale * log(2(1 - q))`` for ``q >= 0.5``.
        """
        sf = self._sf
        scale_f = float(scale)
        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float("inf"), out)
        out = sf.where(q_f == 1.0, -float("inf"), out)
        small = (q_f > 0.0) & (q_f < 0.5)
        large = (q_f >= 0.5) & (q_f < 1.0)
        out = sf.where(small, float(loc) - scale_f * sf.log(2.0 * q_f), out)
        out = sf.where(large, float(loc) + scale_f * sf.log(2.0 * (1.0 - q_f)), out)
        return out

    def pdf(self, x, *, loc=0.0, scale=1.0):
        """Density ``0.5 * exp(-|z|) / scale``."""
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = sf.abs((x_f - float(loc)) / scale_f)
        return 0.5 * sf.exp(-z) / scale_f

    def rvs(self, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw random variates via ``_rvs_laplace`` (``dtype`` is not forwarded)."""
        return _rvs_laplace(self._sf, size=size, loc=loc, scale=scale)
1234
+
1235
+
1236
class LogisticDistributionBase:
    """Logistic distribution with location ``loc`` and scale ``scale``.

    All array math is delegated to the injected backend. Every method
    returns an all-NaN result when ``scale <= 0``.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def cdf(self, x, *, loc=0.0, scale=1.0):
        """CDF ``1 / (1 + exp(-z))``."""
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return 1.0 / (1.0 + sf.exp(-z))

    def sf(self, x, *, loc=0.0, scale=1.0):
        """Survival function ``1 / (1 + exp(z))``."""
        sf = self._sf
        scale_f = float(scale)
        x_f = sf.as_float64(x)
        if scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (x_f - float(loc)) / scale_f
        return 1.0 / (1.0 + sf.exp(z))

    def ppf(self, q, *, loc=0.0, scale=1.0):
        """Quantile ``loc + scale * log(q / (1 - q))``."""
        sf = self._sf
        scale_f = float(scale)
        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, -float("inf"), out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + scale_f * sf.log(q_f / (1.0 - q_f)), out)

    def isf(self, q, *, loc=0.0, scale=1.0):
        """Inverse survival function ``loc + scale * log((1 - q) / q)``.

        Evaluated directly rather than as ``ppf(1 - q)`` so that very small
        upper-tail probabilities are not rounded away.
        """
        sf = self._sf
        scale_f = float(scale)
        q_f = sf.as_float64(q)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float("inf"), out)
        out = sf.where(q_f == 1.0, -float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + scale_f * sf.log((1.0 - q_f) / q_f), out)

    def pdf(self, x, *, loc=0.0, scale=1.0):
        """Density ``cdf * sf / scale``.

        The upper-tail factor comes from :meth:`sf` rather than
        ``1 - cdf``, which would round to zero for large positive ``z``.
        """
        cdf_x = self.cdf(x, loc=loc, scale=scale)
        scale_f = float(scale)
        if scale_f <= 0.0:
            # cdf already produced the all-NaN result for an invalid scale.
            return cdf_x
        return cdf_x * self.sf(x, loc=loc, scale=scale) / scale_f

    def rvs(self, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw random variates via ``_rvs_logistic`` (``dtype`` is not forwarded)."""
        return _rvs_logistic(self._sf, size=size, loc=loc, scale=scale)
1283
+
1284
+
1285
class Chi2DistributionBase:
    """Chi-square distribution with ``df`` degrees of freedom.

    All array math is delegated to the injected backend. Every method
    returns an all-NaN result when ``df <= 0``.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def _nan_like(self, arr):
        """Return an all-NaN result with the shape of ``arr``."""
        nan = float("nan")
        return self._sf.where(arr * 0 + 1, nan, nan)

    def cdf(self, x, df):
        """CDF ``P(df / 2, x / 2)`` for ``x > 0``, else 0."""
        backend = self._sf
        pts = backend.as_float64(x)
        dof = float(df)
        if dof <= 0.0:
            return self._nan_like(pts)
        lower = backend.gammainc(dof / 2.0, pts / 2.0)
        return backend.where(pts <= 0.0, 0.0, lower)

    def sf(self, x, df):
        """Survival function ``Q(df / 2, x / 2)`` for ``x > 0``, else 1."""
        backend = self._sf
        pts = backend.as_float64(x)
        dof = float(df)
        if dof <= 0.0:
            return self._nan_like(pts)
        upper = backend.gammaincc(dof / 2.0, pts / 2.0)
        return backend.where(pts <= 0.0, 1.0, upper)

    def ppf(self, q, df):
        """Quantile ``2 * gammaincinv(df / 2, q)``.

        ``q == 0`` maps to 0, ``q == 1`` to ``+inf`` and values outside
        ``[0, 1]`` to NaN.
        """
        backend = self._sf
        probs = backend.as_float64(q)
        dof = float(df)
        result = self._nan_like(probs)
        if dof <= 0.0:
            return result
        result = backend.where(probs == 0.0, 0.0, result)
        result = backend.where(probs == 1.0, float("inf"), result)
        interior = (probs > 0.0) & (probs < 1.0)
        quantile = 2.0 * backend.gammaincinv(dof / 2.0, probs)
        return backend.where(interior, quantile, result)

    def isf(self, q, df):
        """Inverse survival function, evaluated as ``ppf(1 - q)``."""
        upper = 1.0 - self._sf.as_float64(q)
        return self.ppf(upper, df)

    def pdf(self, x, df):
        """Density for ``x > 0`` (computed in log space), else 0."""
        backend = self._sf
        pts = backend.as_float64(x)
        dof = float(df)
        if dof <= 0.0:
            return self._nan_like(pts)
        # Floor the argument so log() stays finite; non-positive x is
        # masked to 0 below.
        y = backend.maximum(pts, 1e-300)
        log_density = (
            ((dof / 2.0) - 1.0) * backend.log(y)
            - y / 2.0
            - (dof / 2.0) * backend.log(2.0)
            - backend.gammaln(dof / 2.0)
        )
        return backend.where(pts > 0.0, backend.exp(log_density), 0.0)

    def rvs(self, df, *, size=None, dtype=None):
        """Draw random variates via ``_rvs_chi2`` (``dtype`` is not forwarded)."""
        return _rvs_chi2(self._sf, df=df, size=size)
1334
+
1335
+
1336
class GammaDistributionBase:
    """Gamma distribution with shape ``a``, location ``loc`` and ``scale``.

    All array math is delegated to the injected backend. Every method
    returns an all-NaN result when ``a <= 0`` or ``scale <= 0``.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def _nan_like(self, arr):
        """Return an all-NaN result with the shape of ``arr``."""
        nan = float("nan")
        return self._sf.where(arr * 0 + 1, nan, nan)

    def cdf(self, x, a, *, loc=0.0, scale=1.0):
        """CDF ``P(a, y)`` for standardized ``y > 0``, else 0."""
        backend = self._sf
        pts = backend.as_float64(x)
        shape = float(a)
        width = float(scale)
        if shape <= 0.0 or width <= 0.0:
            return self._nan_like(pts)
        y = (pts - float(loc)) / width
        return backend.where(y <= 0.0, 0.0, backend.gammainc(shape, y))

    def sf(self, x, a, *, loc=0.0, scale=1.0):
        """Survival function ``Q(a, y)`` for standardized ``y > 0``, else 1."""
        backend = self._sf
        pts = backend.as_float64(x)
        shape = float(a)
        width = float(scale)
        if shape <= 0.0 or width <= 0.0:
            return self._nan_like(pts)
        y = (pts - float(loc)) / width
        return backend.where(y <= 0.0, 1.0, backend.gammaincc(shape, y))

    def ppf(self, q, a, *, loc=0.0, scale=1.0):
        """Quantile ``loc + scale * gammaincinv(a, q)``.

        ``q == 0`` maps to ``loc``, ``q == 1`` to ``+inf`` and values
        outside ``[0, 1]`` to NaN.
        """
        backend = self._sf
        probs = backend.as_float64(q)
        shape = float(a)
        width = float(scale)
        result = self._nan_like(probs)
        if shape <= 0.0 or width <= 0.0:
            return result
        result = backend.where(probs == 0.0, float(loc), result)
        result = backend.where(probs == 1.0, float("inf"), result)
        interior = (probs > 0.0) & (probs < 1.0)
        quantile = float(loc) + width * backend.gammaincinv(shape, probs)
        return backend.where(interior, quantile, result)

    def isf(self, q, a, *, loc=0.0, scale=1.0):
        """Inverse survival function, evaluated as ``ppf(1 - q)``."""
        upper = 1.0 - self._sf.as_float64(q)
        return self.ppf(upper, a, loc=loc, scale=scale)

    def pdf(self, x, a, *, loc=0.0, scale=1.0):
        """Density for standardized ``y > 0`` (computed in log space), else 0."""
        backend = self._sf
        pts = backend.as_float64(x)
        shape = float(a)
        width = float(scale)
        if shape <= 0.0 or width <= 0.0:
            return self._nan_like(pts)
        y = (pts - float(loc)) / width
        # Floor the argument so log() stays finite; y <= 0 is masked below.
        y_floor = backend.maximum(y, 1e-300)
        log_density = (
            (shape - 1.0) * backend.log(y_floor)
            - y_floor
            - backend.gammaln(shape)
            - backend.log(width)
        )
        return backend.where(y > 0.0, backend.exp(log_density), 0.0)

    def rvs(self, a, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw random variates via ``_rvs_gamma`` (``dtype`` is not forwarded)."""
        return _rvs_gamma(self._sf, a=a, size=size, loc=loc, scale=scale)
1390
+
1391
+
1392
class BetaDistributionBase:
    """Beta distribution with shapes ``a``, ``b`` on ``[loc, loc + scale]``.

    All array math is delegated to the injected backend. Every method
    returns an all-NaN result when ``a <= 0``, ``b <= 0`` or ``scale <= 0``.
    Upper-tail quantities use the reflection
    ``1 - I_y(a, b) = I_{1-y}(b, a)`` so small tails keep their precision.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def cdf(self, x, a, b, *, loc=0.0, scale=1.0):
        """CDF ``I_y(a, b)`` of the standardized value ``y``."""
        sf = self._sf
        x_f = sf.as_float64(x)
        a_f = float(a)
        b_f = float(b)
        scale_f = float(scale)
        if a_f <= 0.0 or b_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        core = sf.betainc(a_f, b_f, sf.clip(y, 0.0, 1.0))
        out = sf.where(y <= 0.0, 0.0, core)
        return sf.where(y >= 1.0, 1.0, out)

    def sf(self, x, a, b, *, loc=0.0, scale=1.0):
        """Survival function ``I_{1-y}(b, a)``.

        Equal to ``1 - cdf`` but evaluated on the complementary argument,
        which avoids cancellation when the upper tail is tiny.
        """
        sf = self._sf
        x_f = sf.as_float64(x)
        a_f = float(a)
        b_f = float(b)
        scale_f = float(scale)
        if a_f <= 0.0 or b_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        core = sf.betainc(b_f, a_f, sf.clip(1.0 - y, 0.0, 1.0))
        out = sf.where(y <= 0.0, 1.0, core)
        return sf.where(y >= 1.0, 0.0, out)

    def ppf(self, q, a, b, *, loc=0.0, scale=1.0):
        """Quantile ``loc + scale * betaincinv(a, b, q)``."""
        sf = self._sf
        q_f = sf.as_float64(q)
        a_f = float(a)
        b_f = float(b)
        scale_f = float(scale)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if a_f <= 0.0 or b_f <= 0.0 or scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float(loc), out)
        out = sf.where(q_f == 1.0, float(loc) + scale_f, out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + scale_f * sf.betaincinv(a_f, b_f, q_f), out)

    def isf(self, q, a, b, *, loc=0.0, scale=1.0):
        """Inverse survival function via the reflected inverse beta.

        Computes ``loc + scale * (1 - betaincinv(b, a, q))`` instead of
        ``ppf(1 - q)`` so that tiny ``q`` is not rounded away. ``q == 0``
        maps to ``loc + scale`` and ``q == 1`` to ``loc``.
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        a_f = float(a)
        b_f = float(b)
        scale_f = float(scale)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if a_f <= 0.0 or b_f <= 0.0 or scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float(loc) + scale_f, out)
        out = sf.where(q_f == 1.0, float(loc), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        reflected = 1.0 - sf.betaincinv(b_f, a_f, q_f)
        return sf.where(valid, float(loc) + scale_f * reflected, out)

    def pdf(self, x, a, b, *, loc=0.0, scale=1.0):
        """Density on the open support (computed in log space), else 0."""
        sf = self._sf
        x_f = sf.as_float64(x)
        a_f = float(a)
        b_f = float(b)
        scale_f = float(scale)
        if a_f <= 0.0 or b_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        # Keeps log() finite at the lower edge; points outside (0, 1) are
        # masked to 0 by ``in_support`` below.
        y_safe = sf.clip(y, 1e-300, 1.0 - 1e-300)
        betaln = sf.gammaln(a_f) + sf.gammaln(b_f) - sf.gammaln(a_f + b_f)
        logpdf = (a_f - 1.0) * sf.log(y_safe) + (b_f - 1.0) * sf.log1p(-y_safe) - betaln - sf.log(scale_f)
        in_support = (y > 0.0) & (y < 1.0)
        return sf.where(in_support, sf.exp(logpdf), 0.0)

    def rvs(self, a, b, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw random variates via ``_rvs_beta`` (``dtype`` is not forwarded)."""
        return _rvs_beta(self._sf, a=a, b=b, size=size, loc=loc, scale=scale)
1446
+
1447
+
1448
class FDistributionBase:
    """F distribution with ``dfn`` and ``dfd`` degrees of freedom.

    All array math is delegated to the injected backend. Every method
    returns an all-NaN result when ``dfn <= 0`` or ``dfd <= 0``.
    Upper-tail quantities are evaluated on the complementary beta argument
    ``dfd / (dfd + dfn * x)`` so that small tail probabilities (typical
    F-test p-values) are not lost to cancellation.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def cdf(self, x, dfn, dfd):
        """CDF ``I_z(dfn / 2, dfd / 2)`` with ``z = dfn x / (dfn x + dfd)``."""
        sf = self._sf
        x_f = sf.as_float64(x)
        dfn_f = float(dfn)
        dfd_f = float(dfd)
        if dfn_f <= 0.0 or dfd_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        z = (dfn_f * sf.maximum(x_f, 0.0)) / (dfn_f * sf.maximum(x_f, 0.0) + dfd_f)
        core = sf.betainc(dfn_f / 2.0, dfd_f / 2.0, z)
        return sf.where(x_f <= 0.0, 0.0, core)

    def sf(self, x, dfn, dfd):
        """Survival function ``I_w(dfd / 2, dfn / 2)``, ``w = dfd / (dfd + dfn x)``.

        Equal to ``1 - cdf`` by the beta reflection identity, but computed
        directly so tiny upper tails do not round to zero.
        """
        sf = self._sf
        x_f = sf.as_float64(x)
        dfn_f = float(dfn)
        dfd_f = float(dfd)
        if dfn_f <= 0.0 or dfd_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        w = dfd_f / (dfd_f + dfn_f * sf.maximum(x_f, 0.0))
        core = sf.betainc(dfd_f / 2.0, dfn_f / 2.0, w)
        return sf.where(x_f <= 0.0, 1.0, core)

    def ppf(self, q, dfn, dfd):
        """Quantile function; 0 at ``q == 0`` and ``+inf`` at ``q == 1``."""
        sf = self._sf
        q_f = sf.as_float64(q)
        dfn_f = float(dfn)
        dfd_f = float(dfd)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if dfn_f <= 0.0 or dfd_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, 0.0, out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        z = sf.betaincinv(dfn_f / 2.0, dfd_f / 2.0, q_f)
        return sf.where(valid, (dfd_f * z) / (dfn_f * (1.0 - z)), out)

    def isf(self, q, dfn, dfd):
        """Inverse survival function via the reflected inverse beta.

        With ``w = betaincinv(dfd / 2, dfn / 2, q)`` the result is
        ``dfd * (1 - w) / (dfn * w)``. This avoids ``ppf(1 - q)``, which
        rounds tiny ``q`` away. ``q == 0`` maps to ``+inf`` and ``q == 1``
        to 0.
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        dfn_f = float(dfn)
        dfd_f = float(dfd)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if dfn_f <= 0.0 or dfd_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float("inf"), out)
        out = sf.where(q_f == 1.0, 0.0, out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        w = sf.betaincinv(dfd_f / 2.0, dfn_f / 2.0, q_f)
        return sf.where(valid, (dfd_f * (1.0 - w)) / (dfn_f * w), out)

    def pdf(self, x, dfn, dfd):
        """Density for ``x > 0`` (computed in log space), else 0."""
        sf = self._sf
        x_f = sf.as_float64(x)
        dfn_f = float(dfn)
        dfd_f = float(dfd)
        if dfn_f <= 0.0 or dfd_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        a = dfn_f / 2.0
        b = dfd_f / 2.0
        # Floor the argument so log() stays finite; x <= 0 is masked below.
        x_safe = sf.maximum(x_f, 1e-300)
        betaln = sf.gammaln(a) + sf.gammaln(b) - sf.gammaln(a + b)
        logpdf = a * sf.log(dfn_f / dfd_f) + (a - 1.0) * sf.log(x_safe) - betaln - (a + b) * sf.log1p((dfn_f / dfd_f) * x_safe)
        return sf.where(x_f > 0.0, sf.exp(logpdf), 0.0)

    def rvs(self, dfn, dfd, *, size=None, dtype=None):
        """Draw random variates via ``_rvs_f`` (``dtype`` is not forwarded)."""
        return _rvs_f(self._sf, dfn=dfn, dfd=dfd, size=size)
1499
+
1500
+
1501
class WeibullMinDistributionBase:
    """Weibull-minimum distribution evaluated through a special-function backend.

    ``c`` (shape), ``loc`` and ``scale`` are treated as Python scalars. A
    non-positive ``c`` or ``scale`` produces an all-NaN result shaped like
    the array argument.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def cdf(self, x, c, *, loc=0.0, scale=1.0):
        """Return ``1 - exp(-y**c)`` with ``y = (x - loc) / scale``; zero for ``y <= 0``."""
        sf = self._sf
        x_f = sf.as_float64(x)
        c_f = float(c)
        scale_f = float(scale)
        if c_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        yc = sf.power(sf.maximum(y, 0.0), c_f)
        return sf.where(y <= 0.0, 0.0, 1.0 - sf.exp(-yc))

    def sf(self, x, c, *, loc=0.0, scale=1.0):
        """Return ``exp(-y**c)`` with ``y = (x - loc) / scale``; one for ``y <= 0``."""
        sf = self._sf
        x_f = sf.as_float64(x)
        c_f = float(c)
        scale_f = float(scale)
        if c_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        yc = sf.power(sf.maximum(y, 0.0), c_f)
        return sf.where(y <= 0.0, 1.0, sf.exp(-yc))

    def ppf(self, q, c, *, loc=0.0, scale=1.0):
        """Return the quantile function.

        NaN for ``q`` outside [0, 1], ``loc`` for ``q == 0`` and ``inf`` for
        ``q == 1``; otherwise ``loc + scale * (-log1p(-q)) ** (1 / c)``.
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        c_f = float(c)
        scale_f = float(scale)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if c_f <= 0.0 or scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float(loc), out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + scale_f * sf.power(-sf.log1p(-q_f), 1.0 / c_f), out)

    def isf(self, q, c, *, loc=0.0, scale=1.0):
        """Return the inverse survival function.

        For ``0 < q < 0.5`` the closed form ``loc + scale * (-log q) ** (1 / c)``
        is used, because ``1 - q`` rounds to one for tiny ``q`` and
        ``ppf(1 - q)`` would then return ``inf``. Every other ``q`` keeps the
        ``ppf(1 - q)`` result.
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        c_f = float(c)
        scale_f = float(scale)
        via_ppf = self.ppf(1.0 - q_f, c, loc=loc, scale=scale)
        if c_f <= 0.0 or scale_f <= 0.0:
            return via_ppf
        small_q = (q_f > 0.0) & (q_f < 0.5)
        # Keep log()'s argument inside (0, 1) for entries the mask discards.
        q_safe = sf.where(small_q, q_f, 0.25)
        direct = float(loc) + scale_f * sf.power(-sf.log(q_safe), 1.0 / c_f)
        return sf.where(small_q, direct, via_ppf)

    def pdf(self, x, c, *, loc=0.0, scale=1.0):
        """Return the density, computed in log space; zero for ``y <= 0``."""
        sf = self._sf
        x_f = sf.as_float64(x)
        c_f = float(c)
        scale_f = float(scale)
        if c_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        # Keep log() finite for entries that the final mask zeroes out.
        y_pos = sf.maximum(y, 1e-300)
        logpdf = sf.log(c_f / scale_f) + (c_f - 1.0) * sf.log(y_pos) - sf.power(y_pos, c_f)
        return sf.where(y > 0.0, sf.exp(logpdf), 0.0)

    def rvs(self, c, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw random variates via ``_rvs_weibull``; ``dtype`` is accepted but unused."""
        return _rvs_weibull(self._sf, c=c, size=size, loc=loc, scale=scale)
1557
+
1558
+
1559
class LognormDistributionBase:
    """Log-normal distribution built on a standard-normal helper object.

    ``s`` (shape), ``loc`` and ``scale`` are treated as Python scalars. A
    non-positive ``s`` or ``scale`` produces an all-NaN result shaped like
    the array argument.
    """

    def __init__(self, sf: SpecialFunctions, norm_dist: NormDistributionBase):
        self._sf = sf
        self._norm = norm_dist

    def cdf(self, x, s, *, loc=0.0, scale=1.0):
        """Return the standard-normal CDF of ``log(y) / s``; zero for ``y <= 0``."""
        sf = self._sf
        x_f = sf.as_float64(x)
        s_f = float(s)
        scale_f = float(scale)
        if s_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        # Clamp before log() so masked-out entries do not produce NaN/-inf.
        z = sf.log(sf.maximum(y, 1e-300)) / s_f
        return sf.where(y <= 0.0, 0.0, self._norm._cdf_standard(z))

    def sf(self, x, s, *, loc=0.0, scale=1.0):
        """Return the standard-normal survival function of ``log(y) / s``; one for ``y <= 0``."""
        sf = self._sf
        x_f = sf.as_float64(x)
        s_f = float(s)
        scale_f = float(scale)
        if s_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        z = sf.log(sf.maximum(y, 1e-300)) / s_f
        return sf.where(y <= 0.0, 1.0, self._norm._sf_standard(z))

    def ppf(self, q, s, *, loc=0.0, scale=1.0):
        """Return the quantile function.

        NaN for ``q`` outside [0, 1], ``loc`` for ``q == 0`` and ``inf`` for
        ``q == 1``; otherwise ``loc + scale * exp(s * z_q)`` with ``z_q`` the
        standard-normal quantile.
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        s_f = float(s)
        scale_f = float(scale)
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if s_f <= 0.0 or scale_f <= 0.0:
            return out
        out = sf.where(q_f == 0.0, float(loc), out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        return sf.where(valid, float(loc) + scale_f * sf.exp(s_f * self._norm._ppf_standard(q_f)), out)

    def isf(self, q, s, *, loc=0.0, scale=1.0):
        """Return the inverse survival function.

        For ``0 < q < 0.5`` the normal-quantile symmetry
        ``z_{1-q} = -z_q`` is used, because ``1 - q`` rounds to one for tiny
        ``q`` and ``ppf(1 - q)`` would then return ``inf``. Every other ``q``
        keeps the ``ppf(1 - q)`` result.
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        s_f = float(s)
        scale_f = float(scale)
        via_ppf = self.ppf(1.0 - q_f, s, loc=loc, scale=scale)
        if s_f <= 0.0 or scale_f <= 0.0:
            return via_ppf
        small_q = (q_f > 0.0) & (q_f < 0.5)
        # Keep the quantile argument inside (0, 1) for entries the mask discards.
        q_safe = sf.where(small_q, q_f, 0.25)
        direct = float(loc) + scale_f * sf.exp(-s_f * self._norm._ppf_standard(q_safe))
        return sf.where(small_q, direct, via_ppf)

    def pdf(self, x, s, *, loc=0.0, scale=1.0):
        """Return the density, computed in log space; zero for ``y <= 0``."""
        sf = self._sf
        x_f = sf.as_float64(x)
        s_f = float(s)
        scale_f = float(scale)
        if s_f <= 0.0 or scale_f <= 0.0:
            return sf.where(x_f * 0 + 1, float("nan"), float("nan"))
        y = (x_f - float(loc)) / scale_f
        # Keep log() finite for entries that the final mask zeroes out.
        y_pos = sf.maximum(y, 1e-300)
        z = sf.log(y_pos) / s_f
        logpdf = -0.5 * sf.square(z) - sf.log(y_pos * s_f * sf.sqrt(2.0 * sf.pi)) - sf.log(scale_f)
        return sf.where(y > 0.0, sf.exp(logpdf), 0.0)

    def rvs(self, s, *, size=None, loc=0.0, scale=1.0, dtype=None):
        """Draw random variates via ``_rvs_lognorm``; ``dtype`` is accepted but unused."""
        return _rvs_lognorm(self._sf, s=s, size=size, loc=loc, scale=scale)
1617
+
1618
+
1619
class PoissonDistributionBase:
    """Poisson distribution evaluated through a special-function backend.

    ``mu`` and ``loc`` are treated as Python scalars. A negative ``mu``
    produces an all-NaN result shaped like the array argument.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def _ppf_search(self, q, mu):
        """Find quantiles for ``loc=0`` by a vectorized integer bisection.

        Entries with ``q`` outside [0, 1] (or any entry when ``mu < 0``) stay
        NaN, ``q == 0`` gives ``-1`` and ``q == 1`` gives ``inf``. For
        ``0 < q < 1`` the CDF is evaluated as ``sf.gammaincc(k + 1, mu)``,
        the same expression ``cdf`` uses.
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        mu_f = float(mu)
        # All-NaN array shaped like q; selected entries are overwritten below.
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if mu_f < 0.0:
            return out
        out = sf.where(q_f == 0.0, -1.0, out)
        out = sf.where(q_f == 1.0, float("inf"), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        if not bool(sf.any(valid)):
            return out
        # Initial upper bracket computed from mu and sqrt(mu + 1).
        hi0 = float(max(1.0, np.ceil(mu_f + 10.0 * np.sqrt(mu_f + 1.0) + 10.0)))
        low = sf.where(q_f * 0 + 1, -1.0, -1.0)
        high = sf.where(q_f * 0 + 1, hi0, hi0)
        # Grow the upper bracket (at most 16 rounds) where its CDF is still below q.
        for _ in range(16):
            cdf_high = sf.where(high < 0.0, 0.0, sf.gammaincc(high + 1.0, mu_f))
            need_expand = valid & (cdf_high < q_f)
            high = sf.where(need_expand, sf.maximum(high * 2.0 + 1.0, 1.0), high)
        # The bisection step count is derived from the largest bracket among valid entries.
        max_high_f = float(np.max(sf.to_numpy(sf.where(valid, high, 0.0))))
        steps = int(np.ceil(np.log2(max(max_high_f + 2.0, 2.0)))) + 2
        for _ in range(max(1, steps)):
            mid = sf.floor((low + high) / 2.0)
            cdf_mid = sf.where(mid < 0.0, 0.0, sf.gammaincc(mid + 1.0, mu_f))
            # CDF(mid) < q: the answer lies above mid; otherwise mid is a valid upper bound.
            move_right = valid & (cdf_mid < q_f)
            low = sf.where(move_right, mid, low)
            high = sf.where(valid & (~move_right), mid, high)
        k = sf.floor(high)
        cdf_k = sf.where(k < 0.0, 0.0, sf.gammaincc(k + 1.0, mu_f))
        # Step up by one where the CDF at k is still below q ...
        k = sf.where(valid & (cdf_k < q_f), k + 1.0, k)
        km1 = k - 1.0
        cdf_km1 = sf.where(km1 < 0.0, 0.0, sf.gammaincc(k, mu_f))
        # ... and step down by one where the CDF at k - 1 already reaches q.
        return sf.where(valid & (km1 >= -1.0) & (cdf_km1 >= q_f), km1, sf.where(valid, k, out))

    def pmf(self, k, mu, *, loc=0):
        """Return the probability mass at ``k - loc``, computed in log space.

        Zero is returned where ``k - loc`` is negative or not an integer.
        """
        sf = self._sf
        k_f = sf.as_float64(k) - float(loc)
        mu_f = float(mu)
        if mu_f < 0.0:
            return sf.where(k_f * 0 + 1, float("nan"), float("nan"))
        k_floor = sf.floor(k_f)
        is_int = (k_floor == k_f)
        valid = (k_f >= 0.0) & is_int
        # Clamp so gammaln()/log() stay finite for entries the mask zeroes out.
        k_safe = sf.maximum(k_floor, 0.0)
        logpmf = k_safe * sf.log(sf.maximum(mu_f, 1e-300)) - mu_f - sf.gammaln(k_safe + 1.0)
        return sf.where(valid, sf.exp(logpmf), 0.0)

    def cdf(self, k, mu, *, loc=0):
        """Return ``gammaincc(floor(k - loc) + 1, mu)``; zero where ``floor(k - loc) < 0``."""
        sf = self._sf
        k_f = sf.as_float64(k) - float(loc)
        mu_f = float(mu)
        if mu_f < 0.0:
            return sf.where(k_f * 0 + 1, float("nan"), float("nan"))
        k_floor = sf.floor(k_f)
        return sf.where(k_floor < 0.0, 0.0, sf.gammaincc(k_floor + 1.0, mu_f))

    def sf(self, k, mu, *, loc=0):
        """Return ``1 - cdf`` passed through ``sf_safe_sub``."""
        return sf_safe_sub(1.0, self.cdf(k, mu, loc=loc), self._sf)

    def ppf(self, q, mu, *, loc=0):
        """Return the ``loc=0`` quantile from ``_ppf_search`` shifted by ``loc``."""
        sf = self._sf
        loc_f = float(loc)
        q_f = sf.as_float64(q)
        return self._ppf_search(q_f, mu) + loc_f

    def isf(self, q, mu, *, loc=0):
        """Return ``ppf(1 - q)``."""
        return self.ppf(1.0 - self._sf.as_float64(q), mu, loc=loc)

    def rvs(self, mu, *, size=None, loc=0, dtype=None):
        """Draw random variates via ``_rvs_poisson``; ``dtype`` is accepted but unused."""
        return _rvs_poisson(self._sf, mu=mu, size=size, loc=loc)
1693
+
1694
+
1695
class BinomDistributionBase:
    """Binomial distribution evaluated through a special-function backend.

    ``n`` is converted with ``int`` and ``p`` with ``float``. ``n < 0`` or
    ``p`` outside [0, 1] produces an all-NaN result shaped like the array
    argument.
    """

    def __init__(self, sf: SpecialFunctions):
        self._sf = sf

    def _ppf_search(self, q, n, p):
        """Find quantiles for ``loc=0`` by a vectorized integer bisection on ``cdf``.

        Entries with ``q`` outside [0, 1] (or any entry for invalid
        parameters) stay NaN, ``q == 0`` gives ``-1`` and ``q == 1`` gives
        ``n``.
        """
        sf = self._sf
        q_f = sf.as_float64(q)
        n_i = int(n)
        p_f = float(p)
        # All-NaN array shaped like q; selected entries are overwritten below.
        out = sf.where(q_f * 0 + 1, float("nan"), float("nan"))
        if n_i < 0 or p_f < 0.0 or p_f > 1.0:
            return out
        out = sf.where(q_f == 0.0, -1.0, out)
        out = sf.where(q_f == 1.0, float(n_i), out)
        valid = (q_f > 0.0) & (q_f < 1.0)
        if not bool(sf.any(valid)):
            return out
        # The search bracket is [-1, n] for every entry.
        low = sf.where(q_f * 0 + 1, -1.0, -1.0)
        high = sf.where(q_f * 0 + 1, float(n_i), float(n_i))
        steps = int(np.ceil(np.log2(max(n_i + 2, 2)))) + 2
        for _ in range(max(1, steps)):
            mid = sf.floor((low + high) / 2.0)
            cdf_mid = self.cdf(mid, n_i, p_f, loc=0)
            # CDF(mid) < q: the answer lies above mid; otherwise mid is a valid upper bound.
            move_right = valid & (cdf_mid < q_f)
            low = sf.where(move_right, mid, low)
            high = sf.where(valid & (~move_right), mid, high)
        k = sf.floor(high)
        cdf_k = self.cdf(k, n_i, p_f, loc=0)
        # Step up by one where the CDF at k is still below q ...
        k = sf.where(valid & (cdf_k < q_f), k + 1.0, k)
        km1 = k - 1.0
        cdf_km1 = self.cdf(km1, n_i, p_f, loc=0)
        # ... and step down by one where the CDF at k - 1 already reaches q.
        return sf.where(valid & (km1 >= -1.0) & (cdf_km1 >= q_f), km1, sf.where(valid, k, out))

    def pmf(self, k, n, p, *, loc=0):
        """Return the probability mass at ``k - loc``, computed in log space.

        Zero is returned where ``k - loc`` is not an integer in ``[0, n]``.
        """
        sf = self._sf
        n_i = int(n)
        p_f = float(p)
        k_f = sf.as_float64(k) - float(loc)
        if n_i < 0 or p_f < 0.0 or p_f > 1.0:
            return sf.where(k_f * 0 + 1, float("nan"), float("nan"))
        k_floor = sf.floor(k_f)
        is_int = (k_floor == k_f)
        valid = (k_floor >= 0.0) & (k_floor <= float(n_i)) & is_int
        # Clamp so gammaln() stays finite for entries the mask zeroes out.
        k_safe = sf.clip(k_floor, 0.0, float(n_i))
        logcoef = sf.gammaln(n_i + 1.0) - sf.gammaln(k_safe + 1.0) - sf.gammaln(n_i - k_safe + 1.0)
        logpmf = logcoef + k_safe * sf.log(sf.maximum(p_f, 1e-300)) + (n_i - k_safe) * sf.log(sf.maximum(1.0 - p_f, 1e-300))
        return sf.where(valid, sf.exp(logpmf), 0.0)

    def cdf(self, k, n, p, *, loc=0):
        """Return ``betainc(n - j, j + 1, 1 - p)`` with ``j = floor(k - loc)``.

        Zero where ``j < 0`` and one where ``j >= n``.
        """
        sf = self._sf
        n_i = int(n)
        p_f = float(p)
        k_f = sf.as_float64(k) - float(loc)
        if n_i < 0 or p_f < 0.0 or p_f > 1.0:
            return sf.where(k_f * 0 + 1, float("nan"), float("nan"))
        k_floor = sf.floor(k_f)
        out = sf.where(k_floor < 0.0, 0.0, sf.betainc(n_i - k_floor, k_floor + 1.0, 1.0 - p_f))
        return sf.where(k_floor >= float(n_i), 1.0, out)

    def sf(self, k, n, p, *, loc=0):
        """Return ``1 - cdf`` passed through ``sf_safe_sub``."""
        return sf_safe_sub(1.0, self.cdf(k, n, p, loc=loc), self._sf)

    def ppf(self, q, n, p, *, loc=0):
        """Return the ``loc=0`` quantile from ``_ppf_search`` shifted by ``loc``."""
        sf = self._sf
        loc_f = float(loc)
        q_f = sf.as_float64(q)
        return self._ppf_search(q_f, n, p) + loc_f

    def isf(self, q, n, p, *, loc=0):
        """Return ``ppf(1 - q)``."""
        return self.ppf(1.0 - self._sf.as_float64(q), n, p, loc=loc)

    def rvs(self, n, p, *, size=None, loc=0, dtype=None):
        """Draw random variates via ``_rvs_binom``; ``dtype`` is accepted but unused."""
        return _rvs_binom(self._sf, n=n, p=p, size=size, loc=loc)
1768
+
1769
+
1770
+ # =============================================================================
1771
+ # Approximation-based inverse special functions
1772
+ # =============================================================================
1773
+ # These provide fast, vectorized (numpy/cupy/torch) initial guesses for
1774
+ # gammaincinv and betaincinv. Used by all three backends.
1775
+ # Reference:
1776
+ # - Wilson-Hilferty (1931) for chi2/gamma
1777
+ # - logit-normal approximation for beta (DiDonato & Morris 1996)
1778
+ # - Newton refinement for 1-2 extra correct digits per step
1779
+
1780
def _gammaincinv_wilson_hilferty(a, q):
    """Starting guess for ``gammaincinv(a, q)`` from the Wilson-Hilferty cube-root rule.

    The normal quantile of ``q`` is pushed through
    ``a * (1 - 1/(9a) + z * sqrt(1/(9a))) ** 3``; any non-positive result is
    replaced by ``1e-10``. Intended for ``a > 0`` and ``q`` in (0, 1), and
    most useful for ``a >= 1``.
    """
    import scipy.special as _scsp
    shape = float(a)
    ninth = 1.0 / (9.0 * shape)
    root_ninth = math.sqrt(ninth)
    probs = np.asarray(q, dtype=np.float64)
    # Standard-normal quantile expressed through erfcinv.
    normal_q = -math.sqrt(2.0) * _scsp.erfcinv(2.0 * probs)
    guess = shape * (1.0 - ninth + normal_q * root_ninth) ** 3
    return np.where(guess > 0, guess, 1e-10)
1793
+
1794
+
1795
def _gammaincinv_a_small(a, q):
    """Starting guess for ``gammaincinv(a, q)`` aimed at ``a < 1``.

    Two candidates are computed and one is picked per element: the
    leading-term series inverse ``(q * Gamma(a + 1)) ** (1 / a)`` when that
    value is below one, and the Wilson-Hilferty guess otherwise.
    """
    import scipy.special as _scsp
    probs = np.asarray(q, dtype=np.float64)
    shape = float(a)
    gamma_a_plus_1 = math.exp(_scsp.gammaln(shape + 1.0))
    # P(a, x) ~ x**a / Gamma(a + 1) for small x, solved for x.
    series_guess = (probs * gamma_a_plus_1) ** (1.0 / shape)
    # Wilson-Hilferty is designed for a >= 1 but is still used for the rest.
    wh_guess = _gammaincinv_wilson_hilferty(shape, probs)
    use_series = series_guess < 1.0
    return np.where(use_series, series_guess, wh_guess)
1810
+
1811
+
1812
def _gammaincinv_newton_numpy(a, q, x0, n_iter=3):
    """Refine an estimate of ``gammaincinv(a, q)`` with Newton iterations.

    Parameters
    ----------
    a : scalar shape parameter (converted with ``float``).
    q : target probabilities (array-like).
    x0 : initial guess (array-like, broadcastable against ``q``).
    n_iter : maximum number of Newton steps (default 3).

    Returns
    -------
    numpy.ndarray with the refined estimate. Iteration stops early once
    every residual ``gammainc(a, x) - q`` is below ``1e-14`` in magnitude.
    Empty input is returned unchanged instead of raising.
    """
    import scipy.special as scsp
    x = np.asarray(x0, dtype=np.float64)
    a_f = float(a)
    log_ga = math.lgamma(a_f)
    q_arr = np.asarray(q, dtype=np.float64)
    for _ in range(n_iter):
        p = scsp.gammainc(a_f, x)
        diff = p - q_arr
        # np.max has no identity for empty arrays, so test the size first.
        if diff.size == 0 or np.max(np.abs(diff)) < 1e-14:
            break
        # d/dx P(a, x) = x**(a-1) * exp(-x) / Gamma(a), evaluated in log space.
        log_deriv = (a_f - 1.0) * np.log(np.clip(x, 1e-300, None)) - x - log_ga
        deriv = np.exp(log_deriv)
        # Bound the slope so the division below cannot blow up.
        deriv = np.clip(deriv, 1e-300, 1e300)
        x = x - diff / deriv
        x = np.clip(x, 1e-15, 1e6)
    return x
1834
+
1835
+
1836
def _betaincinv_logit_approx(a, b, q):
    """Starting guess for ``betaincinv(a, b, q)`` from a logit-normal model.

    The logit of a Beta(a, b) variable is treated as normal with mean
    ``digamma(a) - digamma(b)`` and variance ``1/a + 1/b``; the normal
    quantile of ``q`` is mapped back through the logistic function and
    clipped to ``[1e-15, 1 - 1e-15]``.
    """
    import scipy.special as _scsp
    alpha, beta = float(a), float(b)
    center = _scsp.digamma(alpha) - _scsp.digamma(beta)
    variance = 1.0 / alpha + 1.0 / beta
    spread = math.sqrt(variance)
    probs = np.asarray(q, dtype=np.float64)
    # Standard-normal quantile expressed through erfcinv.
    normal_q = -math.sqrt(2.0) * _scsp.erfcinv(2.0 * probs)
    logit = center + spread * normal_q
    guess = 1.0 / (1.0 + np.exp(-logit))
    return np.clip(guess, 1e-15, 1.0 - 1e-15)
1851
+
1852
+
1853
def _betaincinv_newton_numpy(a, b, q, x0, n_iter=3):
    """Refine an estimate of ``betaincinv(a, b, q)`` with Newton iterations.

    Parameters
    ----------
    a, b : scalar shape parameters (converted with ``float``).
    q : target probabilities (array-like).
    x0 : initial guess (array-like, broadcastable against ``q``).
    n_iter : maximum number of Newton steps (default 3).

    Returns
    -------
    numpy.ndarray with the refined estimate. Iteration stops early once
    every residual ``betainc(a, b, x) - q`` is below ``1e-14`` in magnitude.
    Empty input is returned unchanged instead of raising.
    """
    import scipy.special as scsp
    x = np.asarray(x0, dtype=np.float64)
    a_f, b_f = float(a), float(b)
    log_beta = math.lgamma(a_f) + math.lgamma(b_f) - math.lgamma(a_f + b_f)
    q_arr = np.asarray(q, dtype=np.float64)
    for _ in range(n_iter):
        p = scsp.betainc(a_f, b_f, x)
        diff = p - q_arr
        # np.max has no identity for empty arrays, so test the size first.
        if diff.size == 0 or np.max(np.abs(diff)) < 1e-14:
            break
        # d/dx I_x(a, b) = x**(a-1) * (1-x)**(b-1) / B(a, b), in log space.
        log_deriv = (a_f - 1.0) * np.log(np.clip(x, 1e-300, None)) + \
            (b_f - 1.0) * np.log(np.clip(1.0 - x, 1e-300, None)) - log_beta
        deriv = np.exp(log_deriv)
        # Bound the slope so the division below cannot blow up.
        deriv = np.clip(deriv, 1e-300, 1e300)
        x = x - diff / deriv
        x = np.clip(x, 1e-15, 1.0 - 1e-15)
    return x
1876
+
1877
+
1878
def _t_ppf_cornish_fisher(df, q):
    """Cornish-Fisher expansion of the Student-t quantile function.

    Evaluates the asymptotic series in ``1/df`` around the standard-normal
    quantile ``z`` up to the ``1/df**4`` term (Abramowitz & Stegun 26.7.5)::

        t = z + g1(z)/df + g2(z)/df**2 + g3(z)/df**3 + g4(z)/df**4

    The series avoids a ``betaincinv`` call. Being asymptotic, it is
    accurate for moderate to large ``df`` and degrades for small ``df`` or
    extreme ``q``.

    Parameters
    ----------
    df : scalar degrees of freedom (converted with ``float``).
    q : probabilities in (0, 1) (array-like).

    Returns
    -------
    numpy.ndarray of approximate quantiles.
    """
    import scipy.special as _scsp
    z = -math.sqrt(2.0) * _scsp.erfcinv(2.0 * np.asarray(q, dtype=np.float64))
    z2 = z * z
    inv_df = 1.0 / float(df)
    # Odd polynomials in z, written in Horner form on z**2.
    g1 = z * (z2 + 1.0) / 4.0
    g2 = z * ((5.0 * z2 + 16.0) * z2 + 3.0) / 96.0
    g3 = z * (((3.0 * z2 + 19.0) * z2 + 17.0) * z2 - 15.0) / 384.0
    g4 = z * ((((79.0 * z2 + 776.0) * z2 + 1482.0) * z2 - 1920.0) * z2 - 945.0) / 92160.0
    t_approx = z + inv_df * (g1 + inv_df * (g2 + inv_df * (g3 + inv_df * g4)))
    return np.asarray(t_approx)
1899
+
1900
+
1901
def _t_ppf_hall_approx(df, q):
    """Series approximation to the Student-t quantile function.

    Expands the quantile around the standard-normal quantile ``z`` in powers
    of ``1/df`` up to the ``1/df**4`` term, using the coefficient
    polynomials of Abramowitz & Stegun 26.7.5. The result is an asymptotic
    approximation: accurate for moderate to large ``df``, less so for small
    ``df`` or extreme ``q``.

    Parameters
    ----------
    df : scalar degrees of freedom (converted with ``float``).
    q : probabilities in (0, 1) (scalar or array-like).

    Returns
    -------
    numpy.ndarray of approximate quantiles.
    """
    import scipy.special as scsp
    nu = float(df)
    # Coerce first so that list input works with the arithmetic below.
    q_arr = np.asarray(q, dtype=np.float64)
    z = -math.sqrt(2.0) * scsp.erfcinv(2.0 * q_arr)
    z2 = z * z
    z3 = z2 * z
    z5 = z3 * z2
    z7 = z5 * z2
    z9 = z7 * z2

    t = z + (z3 + z) / (4.0 * nu) + \
        (5.0 * z5 + 16.0 * z3 + 3.0 * z) / (96.0 * nu * nu) + \
        (3.0 * z7 + 19.0 * z5 + 17.0 * z3 - 15.0 * z) / (384.0 * nu * nu * nu) + \
        (79.0 * z9 + 776.0 * z7 + 1482.0 * z5 - 1920.0 * z3 - 945.0 * z) / (92160.0 * nu * nu * nu * nu)
    return np.asarray(t)
1929
+
1930
+
1931
def _t_ppf_wilson_hilferty_approx(df, q):
    """Two-term series approximation to the Student-t quantile function.

    Computes the standard-normal quantile ``z`` (through the upper tail, with
    the sign restored afterwards) and applies the first two corrections of
    the Cornish-Fisher expansion (Abramowitz & Stegun 26.7.5)::

        t = z * (1 + (z**2 + 1)/(4 df) + (5 z**4 + 16 z**2 + 3)/(96 df**2))

    This is the cheapest of the t-quantile helpers; it is best for moderate
    ``|z|`` and ``df`` well above 1.

    Parameters
    ----------
    df : scalar degrees of freedom (converted with ``float``).
    q : probabilities in (0, 1) (scalar or array-like).

    Returns
    -------
    numpy.ndarray of approximate quantiles.
    """
    import scipy.special as scsp
    df_f = float(df)
    # Coerce first so that list input works with the arithmetic below.
    q_arr = np.asarray(q, dtype=np.float64)
    # Reflect into the upper half and remember which side each q came from.
    sign = np.sign(q_arr - 0.5)
    sign = np.where(sign == 0, 1.0, sign)
    q_abs = np.abs(q_arr - 0.5) + 0.5  # always in [0.5, 1]

    # z = Phi^{-1}(q)
    z = -math.sqrt(2.0) * scsp.erfcinv(2.0 * q_abs)
    z = z * sign

    z2 = z * z
    t = z * (1.0 + (z2 + 1.0) / (4.0 * df_f) + ((5.0 * z2 + 16.0) * z2 + 3.0) / (96.0 * df_f * df_f))
    return np.asarray(t)
1955
+
1956
+
1957
+ # =============================================================================
1958
+ # Scalar-function helpers (atan, log, log1p, square, abs, power, floor)
1959
+ # =============================================================================
1960
+ # These need per-backend implementations. We store them on the sf objects
1961
+ # but provide fallbacks for protocols that don't define them.
1962
+
1963
def _scalar_op(sf, name, *args):
    """Dispatch operation ``name`` to ``sf`` when it has one, else to NumPy.

    When ``sf`` exposes a non-``None`` attribute called ``name`` it is called
    with ``args`` as given. Otherwise the NumPy function of the same name is
    called with every argument converted through ``np.asarray``.
    """
    backend_fn = getattr(sf, name, None)
    if backend_fn is None:
        numpy_fn = getattr(np, name)
        return numpy_fn(*(np.asarray(arg) for arg in args))
    return backend_fn(*args)
1970
+
1971
+
1972
class _SpecialFunctionsMixin:
    """Mixin providing NumPy-backed elementwise helpers.

    Each method converts its arguments with ``np.asarray`` and applies the
    NumPy function of the same name (``atan`` maps to ``np.arctan``).
    Subclasses override individual methods to supply backend-native
    versions.

    The methods call NumPy directly: looking the operation up on
    ``type(self)`` would find this mixin's own unbound method and call it
    without ``self``.
    """

    def sqrt(self, x):
        return np.sqrt(np.asarray(x))

    def log(self, x):
        return np.log(np.asarray(x))

    def log1p(self, x):
        return np.log1p(np.asarray(x))

    def square(self, x):
        return np.square(np.asarray(x))

    def abs(self, x):
        return np.abs(np.asarray(x))

    def power(self, x, y):
        return np.power(np.asarray(x), np.asarray(y))

    def floor(self, x):
        return np.floor(np.asarray(x))

    def atan(self, x):
        return np.arctan(np.asarray(x))

    def exp(self, x):
        return np.exp(np.asarray(x))

    def maximum(self, x, y):
        return np.maximum(np.asarray(x), np.asarray(y))

    def minimum(self, x, y):
        return np.minimum(np.asarray(x), np.asarray(y))

    def any(self, x):
        return np.any(np.asarray(x))

    def to_numpy(self, x):
        return np.asarray(x)
2013
+
2014
+
2015
# Patch scalar ops into each concrete implementation
# First pass: every class receives a NumPy-based default for each helper name.
# The defaults are attached with setattr, so they replace any attribute of the
# same name already present on the class.
for _cls in (CuPySpecialFunctions, TorchSpecialFunctions, ScipySpecialFunctions):
    for _name in ("sqrt", "log", "log1p", "square", "abs", "power", "floor", "atan", "tan", "exp", "maximum", "minimum", "any", "to_numpy"):
        # Helper names whose NumPy counterpart is spelled differently are mapped here.
        _np_name = {"atan": "arctan", "tan": "tan"}.get(_name, _name)
        if _name in ("power", "maximum", "minimum"):
            # The default argument freezes the current NumPy name inside the lambda.
            def _make_bin(_n=_np_name):
                return lambda self, x, y: getattr(np, _n)(np.asarray(x), np.asarray(y))
            setattr(_cls, _name, _make_bin())
        elif _name == "any":
            def _make_any():
                return lambda self, x: np.any(np.asarray(x))
            setattr(_cls, _name, _make_any())
        else:
            # Note: for "to_numpy" this default would look up np.to_numpy when
            # called; all three classes get an explicit to_numpy below.
            def _make_fn(_n=_np_name):
                return lambda self, x: getattr(np, _n)(np.asarray(x))
            setattr(_cls, _name, _make_fn())

# Now override with backend-native versions
# Fix to_numpy for ScipySpecialFunctions (np.to_numpy doesn't exist)
ScipySpecialFunctions.to_numpy = lambda self, x: np.asarray(x)

# CuPy versions: inputs are converted to float64 arrays on self._cp before the call.
CuPySpecialFunctions.sqrt = lambda self, x: self._cp.sqrt(self._cp.asarray(x, dtype=self._cp.float64))
CuPySpecialFunctions.log = lambda self, x: self._cp.log(self._cp.asarray(x, dtype=self._cp.float64))
CuPySpecialFunctions.log1p = lambda self, x: self._cp.log1p(self._cp.asarray(x, dtype=self._cp.float64))
CuPySpecialFunctions.square = lambda self, x: self._cp.square(self._cp.asarray(x, dtype=self._cp.float64))
CuPySpecialFunctions.abs = lambda self, x: self._cp.abs(self._cp.asarray(x, dtype=self._cp.float64))
CuPySpecialFunctions.power = lambda self, x, y: self._cp.power(self._cp.asarray(x, dtype=self._cp.float64), self._cp.asarray(y, dtype=self._cp.float64))
CuPySpecialFunctions.floor = lambda self, x: self._cp.floor(self._cp.asarray(x, dtype=self._cp.float64))
CuPySpecialFunctions.atan = lambda self, x: self._cp.arctan(self._cp.asarray(x, dtype=self._cp.float64))
CuPySpecialFunctions.tan = lambda self, x: self._cp.tan(self._cp.asarray(x, dtype=self._cp.float64))
CuPySpecialFunctions.exp = lambda self, x: self._cp.exp(self._cp.asarray(x, dtype=self._cp.float64))
CuPySpecialFunctions.maximum = lambda self, x, y: self._cp.maximum(self._cp.asarray(x, dtype=self._cp.float64), self._cp.asarray(y, dtype=self._cp.float64))
CuPySpecialFunctions.minimum = lambda self, x, y: self._cp.minimum(self._cp.asarray(x, dtype=self._cp.float64), self._cp.asarray(y, dtype=self._cp.float64))
CuPySpecialFunctions.any = lambda self, x: self._cp.any(x)
# Objects exposing .get are converted with asnumpy; anything else goes through np.asarray.
CuPySpecialFunctions.to_numpy = lambda self, x: self._cp.asnumpy(x) if hasattr(x, 'get') else np.asarray(x)

# Torch versions: inputs are converted with self._as_tensor before the call.
TorchSpecialFunctions.sqrt = lambda self, x: self._torch.sqrt(self._as_tensor(x))
TorchSpecialFunctions.log = lambda self, x: self._torch.log(self._as_tensor(x))
TorchSpecialFunctions.log1p = lambda self, x: self._torch.log1p(self._as_tensor(x))
TorchSpecialFunctions.square = lambda self, x: self._torch.square(self._as_tensor(x))
TorchSpecialFunctions.abs = lambda self, x: self._torch.abs(self._as_tensor(x))
TorchSpecialFunctions.power = lambda self, x, y: self._torch.pow(self._as_tensor(x), self._as_tensor(y))
TorchSpecialFunctions.floor = lambda self, x: self._torch.floor(self._as_tensor(x))
TorchSpecialFunctions.atan = lambda self, x: self._torch.atan(self._as_tensor(x))
TorchSpecialFunctions.tan = lambda self, x: self._torch.tan(self._as_tensor(x))
TorchSpecialFunctions.exp = lambda self, x: self._torch.exp(self._as_tensor(x))
TorchSpecialFunctions.maximum = lambda self, x, y: self._torch.maximum(self._as_tensor(x), self._as_tensor(y))
TorchSpecialFunctions.minimum = lambda self, x, y: self._torch.minimum(self._as_tensor(x), self._as_tensor(y))
TorchSpecialFunctions.any = lambda self, x: self._torch.any(x)
# Objects exposing .cpu are detached and moved to the CPU first; anything else goes through np.asarray.
TorchSpecialFunctions.to_numpy = lambda self, x: x.detach().cpu().numpy() if hasattr(x, 'cpu') else np.asarray(x)
2065
+
2066
+
2067
+ # =============================================================================
2068
+ # Safe scalar operations (clip 1-x to [0,1] for SF computation)
2069
+ # =============================================================================
2070
+
2071
def sf_safe_sub(val, other, sf):
    """Return ``val - other`` limited to the closed interval [0, 1] by ``sf.clip``."""
    difference = val - other
    return sf.clip(difference, 0.0, 1.0)
2074
+
2075
+
2076
def sf_safe_mul(val, factor, sf):
    """Return ``val * factor`` limited to the closed interval [0, 1] by ``sf.clip``."""
    product = val * factor
    return sf.clip(product, 0.0, 1.0)
2079
+
2080
+
2081
+ # =============================================================================
2082
+ # Random variate helpers (pure-numpy CPU generation, then converted)
2083
+ # =============================================================================
2084
+
2085
def _rvs_normal(sf, *, size, loc, scale):
    """Draw normal variates on the CPU with NumPy's global random state.

    The draws go through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise the NumPy result is returned as is.
    """
    mean = float(loc)
    stddev = float(scale)
    draws = np.random.normal(loc=mean, scale=stddev, size=size)
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2090
+
2091
+
2092
def _rvs_t(sf, *, df, size, loc, scale):
    """Draw Student-t variates on the CPU through ``scipy.stats.t.rvs``.

    The draws go through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise the SciPy result is returned as is.
    """
    # Use scipy fallback for t-distribution rvs
    import scipy.stats as sps
    dof = float(df)
    shift = float(loc)
    spread = float(scale)
    draws = sps.t.rvs(df=dof, size=size, loc=shift, scale=spread)
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2099
+
2100
+
2101
def _rvs_uniform(sf, *, size, loc, scale):
    """Draw uniform variates on ``[loc, loc + scale)`` with NumPy's global random state.

    The draws go through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise the NumPy result is returned as is.
    """
    lower = float(loc)
    upper = float(loc) + float(scale)
    draws = np.random.uniform(low=lower, high=upper, size=size)
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2106
+
2107
+
2108
def _rvs_expon(sf, *, size, loc, scale):
    """Draw shifted exponential variates with NumPy's global random state.

    The draws go through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise the NumPy result is returned as is.
    """
    shift = float(loc)
    waiting = np.random.exponential(scale=float(scale), size=size)
    draws = shift + waiting
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2113
+
2114
+
2115
def _rvs_cauchy(sf, *, size, loc, scale):
    """Draw Cauchy variates by inverse transform of NumPy uniform draws.

    Uniform samples ``u`` are mapped through ``loc + scale * tan(pi * (u - 0.5))``.
    The result goes through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise it is returned as is.
    """
    uniform = np.random.random(size=size)
    draws = float(loc) + float(scale) * np.tan(np.pi * (uniform - 0.5))
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2121
+
2122
+
2123
def _rvs_laplace(sf, *, size, loc, scale):
    """Draw Laplace variates with NumPy's global random state.

    The draws go through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise the NumPy result is returned as is.
    """
    center = float(loc)
    spread = float(scale)
    draws = np.random.laplace(loc=center, scale=spread, size=size)
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2128
+
2129
+
2130
def _rvs_logistic(sf, *, size, loc, scale):
    """Draw logistic variates by inverse transform of NumPy uniform draws.

    Uniform samples ``u`` are mapped through ``loc + scale * log(u / (1 - u))``.
    The result goes through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise it is returned as is.
    """
    uniform = np.random.random(size=size)
    draws = float(loc) + float(scale) * np.log(uniform / (1.0 - uniform))
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2136
+
2137
+
2138
def _rvs_chi2(sf, *, df, size):
    """Draw chi-square variates with NumPy's global random state.

    The draws go through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise the NumPy result is returned as is.
    """
    dof = float(df)
    draws = np.random.chisquare(df=dof, size=size)
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2143
+
2144
+
2145
def _rvs_gamma(sf, *, a, size, loc, scale):
    """Draw shifted gamma variates with NumPy's global random state.

    The draws go through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise the NumPy result is returned as is.
    """
    shift = float(loc)
    unshifted = np.random.gamma(shape=float(a), scale=float(scale), size=size)
    draws = shift + unshifted
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2150
+
2151
+
2152
def _rvs_beta(sf, *, a, b, size, loc, scale):
    """Draw shifted and scaled beta variates with NumPy's global random state.

    The draws go through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise the NumPy result is returned as is.
    """
    shift = float(loc)
    stretch = float(scale)
    unit_draws = np.random.beta(float(a), float(b), size=size)
    draws = shift + stretch * unit_draws
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2157
+
2158
+
2159
+ def _rvs_f(sf, *, dfn, dfd, size):
2160
+ out = np.random.f(dfn=float(dfn), dfd=float(dfd), size=size)
2161
+ if hasattr(sf, 'as_float64'):
2162
+ return sf.as_float64(out)
2163
+ return out
2164
+
2165
+
2166
def _rvs_weibull(sf, *, c, size, loc, scale):
    """Draw shifted and scaled Weibull variates with NumPy's global random state.

    The draws go through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise the NumPy result is returned as is.
    """
    shift = float(loc)
    stretch = float(scale)
    unit_draws = np.random.weibull(a=float(c), size=size)
    draws = shift + stretch * unit_draws
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2171
+
2172
+
2173
def _rvs_lognorm(sf, *, s, size, loc, scale):
    """Draw log-normal variates as ``loc + scale * exp(s * Z)`` with ``Z`` standard normal.

    The draws go through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise the NumPy result is returned as is.
    """
    shift = float(loc)
    stretch = float(scale)
    sigma = float(s)
    normals = np.random.normal(size=size)
    draws = shift + stretch * np.exp(sigma * normals)
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2178
+
2179
+
2180
def _rvs_poisson(sf, *, mu, size, loc):
    """Draw Poisson counts with NumPy's global random state and add ``int(loc)``.

    The draws go through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise the NumPy result is returned as is.
    """
    counts = np.random.poisson(lam=float(mu), size=size)
    draws = counts + int(loc)
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2185
+
2186
+
2187
def _rvs_binom(sf, *, n, p, size, loc):
    """Draw binomial counts with NumPy's global random state and add ``int(loc)``.

    The draws go through ``sf.as_float64`` when ``sf`` has that attribute;
    otherwise the NumPy result is returned as is.
    """
    counts = np.random.binomial(n=int(n), p=float(p), size=size)
    draws = counts + int(loc)
    if not hasattr(sf, 'as_float64'):
        return draws
    return sf.as_float64(draws)
2192
+
2193
+
2194
+ # =============================================================================
2195
+ # Factory
2196
+ # =============================================================================
2197
+
2198
# Registry mapping a distribution name to a callable that builds the
# distribution object from a special-functions backend instance.
_DISTRIBUTION_FACTORIES = {
    "norm": lambda sf: NormDistributionBase(sf),
    "t": lambda sf: TDistributionBase(sf),
    "uniform": lambda sf: UniformDistributionBase(sf),
    "expon": lambda sf: ExponDistributionBase(sf),
    "cauchy": lambda sf: CauchyDistributionBase(sf),
    "laplace": lambda sf: LaplaceDistributionBase(sf),
    "logistic": lambda sf: LogisticDistributionBase(sf),
    "chi2": lambda sf: Chi2DistributionBase(sf),
    "gamma": lambda sf: GammaDistributionBase(sf),
    "beta": lambda sf: BetaDistributionBase(sf),
    "f": lambda sf: FDistributionBase(sf),
    "weibull_min": lambda sf: WeibullMinDistributionBase(sf),
    # The log-normal object also receives a normal distribution built on the same backend.
    "lognorm": lambda sf: LognormDistributionBase(sf, NormDistributionBase(sf)),
    "poisson": lambda sf: PoissonDistributionBase(sf),
    "binom": lambda sf: BinomDistributionBase(sf),
}

# Alphabetically sorted registry keys.
_NATIVE_NAMES = sorted(_DISTRIBUTION_FACTORIES.keys())
2217
+
2218
+
2219
def _make_sf(backend: str, device: str | None = None, *, use_lut: bool = True) -> SpecialFunctions:
    """Instantiate the special-functions implementation for ``backend``.

    ``'numpy'``, ``'cupy'`` and ``'torch'`` are recognised; ``device`` is
    only passed to the torch implementation. Any other name raises
    ``ValueError``.
    """
    if backend == "numpy":
        special = ScipySpecialFunctions(use_lut=use_lut)
    elif backend == "cupy":
        special = CuPySpecialFunctions(use_lut=use_lut)
    elif backend == "torch":
        special = TorchSpecialFunctions(device=device, use_lut=use_lut)
    else:
        raise ValueError(f"Unsupported backend: {backend}")
    return special
2228
+
2229
+
2230
def get_distribution(name: str, backend: str = "auto", device: str | None = None, *, use_lut: bool = True):
    """Build the distribution object registered under ``name``.

    Parameters
    ----------
    name : str
        Distribution name (e.g. ``'norm'``, ``'t'``, ``'chi2'``). If the
        exact name is not registered, its lower-cased form is tried.
    backend : {'auto', 'numpy', 'cupy', 'torch'}, default='auto'
        Backend to use. ``'auto'`` tries cupy first, then torch, and falls
        back to numpy when neither attempt succeeds.
    device : str, optional
        Torch device string (e.g. ``'cuda'``, ``'cuda:0'``, ``'cpu'``).
        Only used when the backend is ``'torch'``.
    use_lut : bool, default=True
        Use the LUT cache plus one Newton refinement step for the inverse
        special functions (``betaincinv``, ``gammaincinv``). ``False``
        selects the full iterative solver (scipy for numpy, Newton-Raphson
        for torch). ``True`` gives a 10-500x speedup for ``t.ppf`` /
        ``f.ppf`` on GPU with negligible accuracy loss (the LUT is built
        from scipy reference values).

    Returns
    -------
    Distribution object with methods: cdf, sf, ppf, isf, pdf, rvs, etc.

    Raises
    ------
    ValueError
        If ``name`` is not a registered distribution (or, through
        ``_make_sf``, if ``backend`` is not supported).
    """
    if backend == "auto":
        gpu_candidates = ["torch"]
        if CuPySpecialFunctions is not None:  # always importable if cupy installed
            gpu_candidates.insert(0, "cupy")
        for candidate in gpu_candidates:
            try:
                return get_distribution(name, backend=candidate, device=device, use_lut=use_lut)
            except Exception:
                # Best-effort probe: any failure moves on to the next backend.
                continue
        backend = "numpy"

    sf = _make_sf(backend, device, use_lut=use_lut)
    factory = _DISTRIBUTION_FACTORIES.get(name)
    if factory is None:
        # Try case-insensitive
        factory = _DISTRIBUTION_FACTORIES.get(name.lower())
    if factory is None:
        raise ValueError(f"Unknown distribution: {name}")
    return factory(sf)
2274
+
2275
+
2276
def list_available_distributions():
    """Return a fresh list containing every natively implemented distribution name."""
    return [*_NATIVE_NAMES]
2279
+
2280
+
2281
+ # =============================================================================
2282
+ # DistributionProxy — module-level lazy singletons
2283
+ # =============================================================================
2284
+
2285
class DistributionProxy:
    """Lazy proxy that resolves the backend on each call.

    The proxy stores only a distribution name and default settings; every
    method call builds the concrete distribution through ``get_distribution``
    and forwards the call to it.

    Supports ``backend=`` keyword override::

        norm.cdf(x)                    # auto backend
        norm.cdf(x, backend="torch")   # force torch

    ``device=`` and ``use_lut=`` can be overridden per call the same way.
    """

    def __init__(self, name: str, default_backend: str = "auto", device: str | None = None, *, use_lut: bool = True):
        self._name = name
        self._default_backend = default_backend
        self._device = device
        self._use_lut = use_lut

    def _resolve(self, kwargs, *arrays):
        """Build the concrete distribution for one call.

        ``backend``, ``device`` and ``use_lut`` are popped from ``kwargs``
        *in place* so that they are not forwarded to the distribution method.
        With ``backend='auto'`` the backend is inferred from ``arrays`` and
        from the remaining keyword values.
        """
        from statgpu.backends import _is_torch_array, _resolve_backend

        backend = kwargs.pop("backend", self._default_backend)
        device = kwargs.pop("device", self._device)
        use_lut = kwargs.pop("use_lut", self._use_lut)
        if backend == "auto":
            backend = _resolve_backend("auto", *arrays, *kwargs.values())
        if backend == "torch" and device is None:
            # Follow the device of the first torch tensor among the inputs.
            for arr in (*arrays, *kwargs.values()):
                if _is_torch_array(arr):
                    device = str(arr.device)
                    break
        return get_distribution(self._name, backend=backend, device=device, use_lut=use_lut)

    def __repr__(self):
        return (f"DistributionProxy({self._name!r}, "
                f"backend={self._default_backend!r}, "
                f"use_lut={self._use_lut!r})")

    def cdf(self, x, **kw):
        """Forward ``cdf`` to the resolved distribution."""
        return self._resolve(kw, x).cdf(x, **kw)

    def sf(self, x, **kw):
        """Forward ``sf`` to the resolved distribution."""
        return self._resolve(kw, x).sf(x, **kw)

    def ppf(self, q, **kw):
        """Forward ``ppf`` to the resolved distribution."""
        return self._resolve(kw, q).ppf(q, **kw)

    def isf(self, q, **kw):
        """Forward ``isf`` to the resolved distribution."""
        return self._resolve(kw, q).isf(q, **kw)

    def pdf(self, x, **kw):
        """Forward ``pdf`` to the resolved distribution."""
        return self._resolve(kw, x).pdf(x, **kw)

    def pmf(self, k, **kw):
        """Forward ``pmf`` to the resolved distribution."""
        return self._resolve(kw, k).pmf(k, **kw)

    def rvs(self, **kw):
        """Forward ``rvs`` to the resolved distribution.

        There is no positional input here; ``_resolve`` already inspects the
        keyword values that remain after the control keywords are removed,
        so nothing extra is passed. (Passing ``*kw.values()`` would leak the
        ``backend``/``device``/``use_lut`` values into backend inference.)
        """
        return self._resolve(kw).rvs(**kw)

    def two_sided_pvalue(self, stat_abs, **kw):
        """Forward ``two_sided_pvalue`` to the resolved distribution."""
        return self._resolve(kw, stat_abs).two_sided_pvalue(stat_abs, **kw)

    def two_sided_critical_value(self, alpha, **kw):
        """Forward ``two_sided_critical_value`` to the resolved distribution."""
        return self._resolve(kw, alpha).two_sided_critical_value(alpha, **kw)
2346
+
2347
+
2348
# Module-level proxy singletons. Constructing a proxy only records the
# distribution name; the backend is chosen on every method call.

# Symmetric location/scale families
norm = DistributionProxy("norm")
t = DistributionProxy("t")
cauchy = DistributionProxy("cauchy")
laplace = DistributionProxy("laplace")
logistic = DistributionProxy("logistic")
uniform = DistributionProxy("uniform")

# Positive-support and ratio families
expon = DistributionProxy("expon")
chi2 = DistributionProxy("chi2")
gamma = DistributionProxy("gamma")
beta = DistributionProxy("beta")
f = DistributionProxy("f")
weibull_min = DistributionProxy("weibull_min")
lognorm = DistributionProxy("lognorm")

# Discrete families
poisson = DistributionProxy("poisson")
binom = DistributionProxy("binom")
2364
+
2365
+
2366
# Legacy ``*DistributionGPU`` names from the old CuPy-only API. Each one is
# simply another binding of the corresponding backend-agnostic ``*Base`` class.

# Symmetric location/scale families
NormDistributionGPU = NormDistributionBase
TDistributionGPU = TDistributionBase
CauchyDistributionGPU = CauchyDistributionBase
LaplaceDistributionGPU = LaplaceDistributionBase
LogisticDistributionGPU = LogisticDistributionBase
UniformDistributionGPU = UniformDistributionBase

# Positive-support and ratio families
ExponDistributionGPU = ExponDistributionBase
Chi2DistributionGPU = Chi2DistributionBase
GammaDistributionGPU = GammaDistributionBase
BetaDistributionGPU = BetaDistributionBase
FDistributionGPU = FDistributionBase
WeibullMinDistributionGPU = WeibullMinDistributionBase
LognormDistributionGPU = LognormDistributionBase

# Discrete families
PoissonDistributionGPU = PoissonDistributionBase
BinomDistributionGPU = BinomDistributionBase
2382
+
2383
+
2384
def get_distribution_gpu(name: str, *, allow_fallback: bool = False):
    """Backward-compatible wrapper: get GPU distribution by name.

    Delegates to the unified factory (``backend="auto"``) for natively
    implemented distributions. Other names are looked up in ``scipy.stats``
    and are only served when ``allow_fallback`` is true.

    Parameters
    ----------
    name : str
        Distribution name; surrounding whitespace is ignored and the
        lower-cased form is tried before the exact one.
    allow_fallback : bool, default=False
        When true, a distribution that exists only in ``scipy.stats`` is
        returned wrapped in ``ScipyFallbackDistribution``.

    Raises
    ------
    ValueError
        If the name is unknown, or if it is a SciPy-only distribution and
        ``allow_fallback`` is false.
    """
    key = str(name).strip()
    lowered = key.lower()
    if lowered in _DISTRIBUTION_FACTORIES:
        return get_distribution(lowered, backend="auto")

    # Only needed for non-native names, so import it lazily here.
    import scipy.stats as sps

    def _looks_like_distribution(obj):
        # scipy.stats also exposes test functions, helper classes and
        # submodules; accept only instantiated distribution objects, i.e.
        # non-class objects offering at least one distribution method.
        if obj is None or isinstance(obj, type):
            return False
        return any(callable(getattr(obj, meth, None)) for meth in ("pdf", "pmf", "cdf", "rvs"))

    scipy_name = next(
        (cand for cand in (lowered, key) if _looks_like_distribution(getattr(sps, cand, None))),
        None,
    )
    if scipy_name is None:
        raise ValueError(f"Unknown scipy.stats distribution: {name}")
    if allow_fallback:
        return ScipyFallbackDistribution(scipy_name)
    raise ValueError(
        f"Distribution '{name}' has no native GPU implementation. "
        "Set allow_fallback=True for explicit SciPy fallback."
    )
2405
+
2406
+
2407
def list_available_distributions_gpu(include_scipy: bool = True):
    """Backward-compatible listing of distribution names.

    With ``include_scipy=False`` the native list is returned unchanged.
    Otherwise the native names are merged with every public ``scipy.stats``
    attribute that is an ``rv_continuous`` or ``rv_discrete`` instance, and
    the union is returned sorted.
    """
    native_names = list_available_distributions()
    if include_scipy:
        import scipy.stats as sps
        from scipy.stats import rv_continuous, rv_discrete

        dist_types = (rv_continuous, rv_discrete)
        merged = set(native_names)
        for attr in dir(sps):
            if attr.startswith("_"):
                continue
            try:
                candidate = getattr(sps, attr)
            except Exception:
                # Attribute lookup itself failed; ignore this entry.
                continue
            if isinstance(candidate, dist_types):
                merged.add(attr)
        return sorted(merged)
    return native_names
2427
+
2428
+
2429
class ScipyFallbackDistribution:
    """Dynamic wrapper around a ``scipy.stats`` distribution.

    Inputs that live on a GPU (CuPy arrays, torch tensors) are copied to
    host NumPy arrays before SciPy is called. The result is converted to the
    package's default backend when that backend is not NumPy; otherwise the
    raw SciPy result is returned.
    """

    def __init__(self, name: str):
        self.name = str(name)

    def __repr__(self):
        return f"ScipyFallbackDistribution('{self.name}')"

    @staticmethod
    def _to_host(value):
        """Return ``value`` as something SciPy accepts (GPU data -> NumPy)."""
        if hasattr(value, "detach"):
            # torch tensor
            return value.detach().cpu().numpy()
        # A bare ``hasattr(value, "get")`` also matches dicts, pandas objects
        # and other mapping-likes whose ``get`` needs a key. Require a CUDA
        # array marker or a cupy type before calling ``.get()``.
        owner = (type(value).__module__ or "").split(".")[0]
        if hasattr(value, "__cuda_array_interface__") or owner in ("cupy", "cupyx"):
            fetch = getattr(value, "get", None)
            if callable(fetch):
                return fetch()
        return value

    def _call(self, method_name, *args, **kwargs):
        """Invoke ``scipy.stats.<name>.<method_name>`` with host-side inputs."""
        import scipy.stats as sps

        method = getattr(getattr(sps, self.name), method_name)
        host_args = [self._to_host(v) for v in args]
        host_kwargs = {k: self._to_host(v) for k, v in kwargs.items()}
        result = method(*host_args, **host_kwargs)
        # Best effort: hand the result back on the default backend when that
        # backend is a GPU one. Any failure keeps the plain SciPy result.
        try:
            from statgpu.backends import get_backend
            backend = get_backend()
            if backend.name != "numpy":
                return backend.asarray(result)
        except Exception:
            pass
        return result

    def cdf(self, x, *shape_args, **kwargs):
        return self._call("cdf", x, *shape_args, **kwargs)

    def sf(self, x, *shape_args, **kwargs):
        return self._call("sf", x, *shape_args, **kwargs)

    def ppf(self, q, *shape_args, **kwargs):
        return self._call("ppf", q, *shape_args, **kwargs)

    def isf(self, q, *shape_args, **kwargs):
        return self._call("isf", q, *shape_args, **kwargs)

    def pdf(self, x, *shape_args, **kwargs):
        return self._call("pdf", x, *shape_args, **kwargs)

    def pmf(self, x, *shape_args, **kwargs):
        return self._call("pmf", x, *shape_args, **kwargs)

    def rvs(self, *shape_args, size=None, dtype=None, **kwargs):
        """Draw samples; cast to ``dtype`` when given and the result supports it."""
        out = self._call("rvs", *shape_args, size=size, **kwargs)
        if dtype is not None and hasattr(out, "astype"):
            out = out.astype(dtype)
        return out
2493
+
2494
+
2495
+ # =============================================================================
2496
+ # Backward-compatible special function aliases (for old consumers)
2497
+ # =============================================================================
2498
+
2499
def regularized_betainc_gpu(a, b, x):
    """Legacy wrapper: ``betainc(a, b, x)`` on a fresh ``CuPySpecialFunctions``.

    New code should go through ``get_distribution`` instead.
    """
    return CuPySpecialFunctions().betainc(a, b, x)
2503
+
2504
+
2505
def regularized_betaincinv_gpu(a, b, y):
    """Legacy wrapper: ``betaincinv(a, b, y)`` on a fresh ``CuPySpecialFunctions``."""
    return CuPySpecialFunctions().betaincinv(a, b, y)
2509
+
2510
+
2511
def gammainc_gpu(a, x):
    """Legacy wrapper: ``gammainc(a, x)`` on a fresh ``CuPySpecialFunctions``."""
    return CuPySpecialFunctions().gammainc(a, x)
2515
+
2516
+
2517
def gammaincc_gpu(a, x):
    """Legacy wrapper: ``gammaincc(a, x)`` on a fresh ``CuPySpecialFunctions``."""
    return CuPySpecialFunctions().gammaincc(a, x)
2521
+
2522
+
2523
def gammaincinv_gpu(a, q):
    """Legacy wrapper: ``gammaincinv(a, q)`` on a fresh ``CuPySpecialFunctions``."""
    return CuPySpecialFunctions().gammaincinv(a, q)
2527
+
2528
+
2529
def gammaln_gpu(x):
    """Legacy wrapper: ``gammaln(x)`` on a fresh ``CuPySpecialFunctions``."""
    return CuPySpecialFunctions().gammaln(x)
2533
+
2534
+
2535
+ # =============================================================================
2536
+ # Legacy distribution-function names (R-style)
2537
+ # =============================================================================
2538
+
2539
# Names that the module-level ``__getattr__`` forwards to the legacy GPU
# distribution module. Grouped by distribution; within each group the
# scipy-style helpers come first, then the R-style d/p/q/r functions.
_LEGACY_DISTRIBUTION_FUNCTION_NAMES = {
    # normal
    "norm_cdf_gpu", "norm_sf_gpu", "norm_ppf_gpu", "norm_isf_gpu",
    "norm_two_sided_pvalue_gpu", "norm_two_sided_critical_value_gpu",
    "dnorm_gpu", "pnorm_gpu", "qnorm_gpu", "rnorm_gpu",
    # Student t
    "t_cdf_gpu", "t_sf_gpu", "t_ppf_gpu",
    "t_two_sided_pvalue_gpu", "t_two_sided_critical_value_gpu",
    "dt_gpu", "pt_gpu", "qt_gpu", "rt_gpu",
    # chi-squared
    "dchisq_gpu", "pchisq_gpu", "qchisq_gpu", "rchisq_gpu",
    # gamma
    "dgamma_gpu", "pgamma_gpu", "qgamma_gpu", "rgamma_gpu",
    # beta
    "dbeta_gpu", "pbeta_gpu", "qbeta_gpu", "rbeta_gpu",
    # F
    "df_gpu", "pf_gpu", "qf_gpu", "rf_gpu",
    # Poisson
    "dpois_gpu", "ppois_gpu", "qpois_gpu", "rpois_gpu",
    # binomial
    "dbinom_gpu", "pbinom_gpu", "qbinom_gpu", "rbinom_gpu",
}
2552
+
2553
+
2554
def __getattr__(name):
    """Module-level attribute hook for names not defined in this module.

    Underscore-prefixed names are rejected immediately. Names listed in
    ``_LEGACY_DISTRIBUTION_FUNCTION_NAMES`` are fetched from the legacy GPU
    distribution module, which is imported on first use. Anything else is
    treated as a distribution name and resolved with
    ``get_distribution_gpu``; any failure there surfaces as ``AttributeError``.
    """
    missing = f"module {__name__} has no attribute {name}"
    if name.startswith("_"):
        raise AttributeError(missing)
    if name not in _LEGACY_DISTRIBUTION_FUNCTION_NAMES:
        try:
            return get_distribution_gpu(name)
        except Exception as exc:
            raise AttributeError(missing) from exc
    from statgpu.linear_model.legacy import _distributions_legacy_gpu as legacy
    return getattr(legacy, name)
2565
+
2566
+
2567
+ # =============================================================================
2568
+ # Exports
2569
+ # =============================================================================
2570
+
2571
__all__ = [
    # Factory entry points and the special-function interface
    "get_distribution", "list_available_distributions",
    "DistributionProxy", "SpecialFunctions",
    # Special-function backends
    "CuPySpecialFunctions", "TorchSpecialFunctions", "ScipySpecialFunctions",
    # Backend-agnostic distribution classes
    "NormDistributionBase", "TDistributionBase", "UniformDistributionBase",
    "ExponDistributionBase", "CauchyDistributionBase", "LaplaceDistributionBase",
    "LogisticDistributionBase", "Chi2DistributionBase", "GammaDistributionBase",
    "BetaDistributionBase", "FDistributionBase", "WeibullMinDistributionBase",
    "LognormDistributionBase", "PoissonDistributionBase", "BinomDistributionBase",
    # Module-level proxy singletons
    "norm",
    "t",
    "uniform",
    "expon",
    "cauchy",
    "laplace",
    "logistic",
    "chi2",
    "gamma",
    "beta",
    "f",
    "weibull_min",
    "lognorm",
    "poisson",
    "binom",
    # Legacy class aliases
    "NormDistributionGPU",
    "TDistributionGPU",
    "UniformDistributionGPU",
    "ExponDistributionGPU",
    "CauchyDistributionGPU",
    "LaplaceDistributionGPU",
    "LogisticDistributionGPU",
    "Chi2DistributionGPU",
    "GammaDistributionGPU",
    "BetaDistributionGPU",
    "FDistributionGPU",
    "WeibullMinDistributionGPU",
    "LognormDistributionGPU",
    "PoissonDistributionGPU",
    "BinomDistributionGPU",
    # Legacy helpers
    "ScipyFallbackDistribution",
    "get_distribution_gpu",
    "list_available_distributions_gpu",
]