pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
pymc_extras/distributions/discrete.py
@@ -0,0 +1,399 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pymc as pm

from pymc.distributions.dist_math import betaln, check_parameters, factln, logpow
from pymc.distributions.shape_utils import rv_size_is_none
from pytensor import tensor as pt
from pytensor.tensor.random.op import RandomVariable


def log1mexp(x):
    cond = x < np.log(0.5)
    return np.piecewise(
        x,
        [cond, ~cond],
        [lambda x: np.log1p(-np.exp(x)), lambda x: np.log(-np.expm1(x))],
    )


class GeneralizedPoissonRV(RandomVariable):
    name = "generalized_poisson"
    signature = "(),()->()"
    dtype = "int64"
    _print_name = ("GeneralizedPoisson", "\\operatorname{GeneralizedPoisson}")

    @classmethod
    def rng_fn(cls, rng, theta, lam, size):
        theta = np.asarray(theta)
        lam = np.asarray(lam)

        if size is not None:
            dist_size = size
        else:
            dist_size = np.broadcast_shapes(theta.shape, lam.shape)

        # A mix of 2 algorithms described by Famoye (1997) is used, depending on parameter values
        # 0: Inverse method, computed on the log scale. Used when lam < 0.
        # 1: Branching method. Used when lam >= 0.
        x = np.empty(dist_size)
        idxs_mask = np.broadcast_to(lam < 0, dist_size)
        if np.any(idxs_mask):
            x[idxs_mask] = cls._inverse_rng_fn(rng, theta, lam, dist_size, idxs_mask=idxs_mask)[
                idxs_mask
            ]
        idxs_mask = ~idxs_mask
        if np.any(idxs_mask):
            x[idxs_mask] = cls._branching_rng_fn(rng, theta, lam, dist_size, idxs_mask=idxs_mask)[
                idxs_mask
            ]
        return x

    @classmethod
    def _inverse_rng_fn(cls, rng, theta, lam, dist_size, idxs_mask):
        # We handle x/0 and log(0) issues with branching
        with np.errstate(divide="ignore", invalid="ignore"):
            log_u = np.log(rng.uniform(size=dist_size))
            pos_lam = lam > 0
            abs_log_lam = np.log(np.abs(lam))
            theta_m_lam = theta - lam
            log_s = -theta
            log_p = log_s.copy()
            x_ = 0
            x = np.zeros(dist_size)
            below_cutpoint = log_s < log_u
            while np.any(below_cutpoint[idxs_mask]):
                x_ += 1
                x[below_cutpoint] += 1
                log_c = np.log(theta_m_lam + lam * x_)
                # Compute log(1 + lam / C)
                log1p_lam_m_C = np.where(
                    pos_lam,
                    np.log1p(np.exp(abs_log_lam - log_c)),
                    log1mexp(abs_log_lam - log_c),
                )
                log_p = log_c + log1p_lam_m_C * (x_ - 1) + log_p - np.log(x_) - lam
                log_s = np.logaddexp(log_s, log_p)
                below_cutpoint = log_s < log_u
            return x

    @classmethod
    def _branching_rng_fn(cls, rng, theta, lam, dist_size, idxs_mask):
        lam_ = np.abs(lam)  # This algorithm is only valid for positive lam
        y = rng.poisson(theta, size=dist_size)
        x = y.copy()
        higher_than_zero = y > 0
        while np.any(higher_than_zero[idxs_mask]):
            y = rng.poisson(lam_ * y)
            x[higher_than_zero] = x[higher_than_zero] + y[higher_than_zero]
            higher_than_zero = y > 0
        return x


generalized_poisson = GeneralizedPoissonRV()


class GeneralizedPoisson(pm.distributions.Discrete):
    R"""
    Generalized Poisson distribution.

    Used to model count data that can be either overdispersed or underdispersed.
    Offers greater flexibility than the standard Poisson, which assumes
    equidispersion (mean equal to variance).

    The pmf of this distribution is

    .. math:: f(x \mid \mu, \lambda) =
        \frac{\mu (\mu + \lambda x)^{x-1} e^{-\mu - \lambda x}}{x!}

    ======== ======================================
    Support  :math:`x \in \mathbb{N}_0`
    Mean     :math:`\frac{\mu}{1 - \lambda}`
    Variance :math:`\frac{\mu}{(1 - \lambda)^3}`
    ======== ======================================

    Parameters
    ----------
    mu : tensor_like of float
        Mean parameter (mu > 0).
    lam : tensor_like of float
        Dispersion parameter (max(-1, -mu/4) <= lam <= 1).

    Notes
    -----
    When lam = 0, the Generalized Poisson reduces to the standard Poisson with the same mu.
    When lam < 0, the mean is greater than the variance (underdispersion).
    When lam > 0, the mean is less than the variance (overdispersion).

    References
    ----------
    The PMF is taken from [1] and the random generator function is adapted from [2].

    .. [1] Consul, P. C., and Felix Famoye. "Generalized Poisson regression model."
       Communications in Statistics - Theory and Methods 21.1 (1992): 89-109.
    .. [2] Famoye, Felix. "Generalized Poisson random variate generation." American
       Journal of Mathematical and Management Sciences 17.3-4 (1997): 219-237.
    """

    rv_op = generalized_poisson

    @classmethod
    def dist(cls, mu, lam, **kwargs):
        mu = pt.as_tensor_variable(mu)
        lam = pt.as_tensor_variable(lam)
        return super().dist([mu, lam], **kwargs)

    def support_point(rv, size, mu, lam):
        mean = pt.floor(mu / (1 - lam))
        if not rv_size_is_none(size):
            mean = pt.full(size, mean)
        return mean

    def logp(value, mu, lam):
        mu_lam_value = mu + lam * value
        logprob = pt.log(mu) + logpow(mu_lam_value, value - 1) - mu_lam_value - factln(value)

        # Probability is 0 when value > m, where m is the largest positive integer for
        # which mu + m * lam > 0 (when lam < 0).
        logprob = pt.switch(
            pt.or_(
                mu_lam_value < 0,
                value < 0,
            ),
            -np.inf,
            logprob,
        )

        return check_parameters(
            logprob,
            0 < mu,
            pt.abs(lam) <= 1,
            (-mu / 4) <= lam,
            msg="0 < mu, max(-1, -mu/4) <= lam <= 1",
        )

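A minimal usage sketch (an editor's illustration, not part of the diff; it assumes GeneralizedPoisson is re-exported from pymc_extras.distributions, whose __init__.py appears in the manifest above). Forward draws from .dist() exercise the Famoye samplers directly, and the class drops into a model like any PyMC discrete distribution.

import numpy as np
import pymc as pm
from pymc_extras.distributions import GeneralizedPoisson

# Underdispersed draws: lam < 0 routes through the inverse (log-scale) sampler.
draws = pm.draw(GeneralizedPoisson.dist(mu=5.0, lam=-0.2), draws=1000, random_seed=1)
print(draws.mean(), draws.var())  # roughly mu/(1-lam) ~ 4.17 and mu/(1-lam)^3 ~ 2.89

with pm.Model():
    mu = pm.Exponential("mu", 1.0)
    lam = pm.Uniform("lam", -0.5, 0.5)
    GeneralizedPoisson("y", mu=mu, lam=lam, observed=draws)
    idata = pm.sample()
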
class BetaNegativeBinomial:
    R"""
    Beta Negative Binomial distribution.

    The pmf of this distribution is

    .. math::

        f(x \mid \alpha, \beta, r) = \frac{B(r + x, \alpha + \beta)}{B(r, \alpha)} \frac{\Gamma(x + \beta)}{x! \Gamma(\beta)}

    where :math:`B` is the Beta function and :math:`\Gamma` is the Gamma function.

    For more information, see https://en.wikipedia.org/wiki/Beta_negative_binomial_distribution.

    .. plot::
        :context: close-figs

        import matplotlib.pyplot as plt
        import numpy as np
        from scipy.special import betaln, gammaln

        def factln(x):
            return gammaln(x + 1)

        def logp(x, alpha, beta, r):
            return (
                betaln(r + x, alpha + beta)
                - betaln(r, alpha)
                + gammaln(x + beta)
                - factln(x)
                - gammaln(beta)
            )

        plt.style.use('arviz-darkgrid')
        x = np.arange(0, 25)
        params = [
            (1, 1, 1),
            (1, 1, 10),
            (1, 10, 1),
            (1, 10, 10),
            (10, 10, 10),
        ]
        for alpha, beta, r in params:
            pmf = np.exp(logp(x, alpha, beta, r))
            plt.plot(x, pmf, "-o", label=r'$\alpha$ = {}, $\beta$ = {}, $r$ = {}'.format(alpha, beta, r))
        plt.xlabel('x', fontsize=12)
        plt.ylabel('f(x)', fontsize=12)
        plt.legend(loc=1)
        plt.show()

    ======== ======================================
    Support  :math:`x \in \mathbb{N}_0`
    Mean     :math:`\begin{cases} \frac{r\beta}{\alpha - 1} & \text{if } \alpha > 1 \\ \infty & \text{otherwise} \end{cases}`
    Variance :math:`\begin{cases} \frac{r\beta(r + \alpha - 1)(\beta + \alpha - 1)}{(\alpha - 2)(\alpha - 1)^2} & \text{if } \alpha > 2 \\ \infty & \text{otherwise} \end{cases}`
    ======== ======================================

    Parameters
    ----------
    alpha : tensor_like of float
        Shape parameter of the beta distribution (alpha > 0).
    beta : tensor_like of float
        Shape parameter of the beta distribution (beta > 0).
    r : tensor_like of float
        Number of successes until the experiment is stopped (integer, but can be extended to real).
    """

    @staticmethod
    def beta_negative_binomial_dist(alpha, beta, r, size):
        if rv_size_is_none(size):
            alpha, beta, r = pt.broadcast_arrays(alpha, beta, r)

        p = pm.Beta.dist(alpha, beta, size=size)
        return pm.NegativeBinomial.dist(p=p, n=r, size=size)

    @staticmethod
    def beta_negative_binomial_logp(value, alpha, beta, r):
        res = (
            betaln(r + value, alpha + beta)
            - betaln(r, alpha)
            + pt.gammaln(value + beta)
            - factln(value)
            - pt.gammaln(beta)
        )
        res = pt.switch(
            pt.lt(value, 0),
            -np.inf,
            res,
        )

        return check_parameters(
            res,
            alpha > 0,
            beta > 0,
            r > 0,
            msg="alpha > 0, beta > 0, r > 0",
        )

    def __new__(cls, name, alpha, beta, r, **kwargs):
        return pm.CustomDist(
            name,
            alpha,
            beta,
            r,
            dist=cls.beta_negative_binomial_dist,
            logp=cls.beta_negative_binomial_logp,
            class_name="BetaNegativeBinomial",
            **kwargs,
        )

    @classmethod
    def dist(cls, alpha, beta, r, **kwargs):
        return pm.CustomDist.dist(
            alpha,
            beta,
            r,
            dist=cls.beta_negative_binomial_dist,
            logp=cls.beta_negative_binomial_logp,
            class_name="BetaNegativeBinomial",
            **kwargs,
        )

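Likewise for BetaNegativeBinomial (a sketch under the same import assumption). Because the generative graph is the Beta -> NegativeBinomial mixture above, forward draws should match the closed-form mean r * beta / (alpha - 1) when alpha > 1.

import pymc as pm
from pymc_extras.distributions import BetaNegativeBinomial

# Prior draws from the marginal mixture.
x = pm.draw(BetaNegativeBinomial.dist(alpha=3.0, beta=2.0, r=5.0), draws=2000, random_seed=1)
print(x.mean())  # should sit near r * beta / (alpha - 1) = 5.0

# As a likelihood, it behaves like any other PyMC distribution.
with pm.Model():
    alpha = pm.HalfNormal("alpha", 5.0)
    beta = pm.HalfNormal("beta", 5.0)
    BetaNegativeBinomial("y", alpha=alpha, beta=beta, r=5.0, observed=x)
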
class Skellam:
    R"""
    Skellam distribution.

    The Skellam distribution is the distribution of the difference of two
    Poisson random variables.

    The pmf of this distribution is

    .. math::

        f(x \mid \mu_1, \mu_2) = e^{-(\mu_1 + \mu_2)} \left( \frac{\mu_1}{\mu_2} \right)^{x/2} I_x(2 \sqrt{\mu_1 \mu_2})

    where :math:`I_x` is the modified Bessel function of the first kind of order :math:`x`.

    Read more about the Skellam distribution at https://en.wikipedia.org/wiki/Skellam_distribution

    .. plot::
        :context: close-figs

        import matplotlib.pyplot as plt
        import numpy as np
        import scipy.stats as st
        import arviz as az
        plt.style.use('arviz-darkgrid')
        x = np.arange(-15, 15)
        params = [
            (1, 1),
            (5, 5),
            (5, 1),
        ]
        for mu1, mu2 in params:
            pmf = st.skellam.pmf(x, mu1, mu2)
            plt.plot(x, pmf, "-o", label=r'$\mu_1$ = {}, $\mu_2$ = {}'.format(mu1, mu2))
        plt.xlabel('x', fontsize=12)
        plt.ylabel('f(x)', fontsize=12)
        plt.legend(loc=1)
        plt.show()

    ======== ======================================
    Support  :math:`x \in \mathbb{Z}`
    Mean     :math:`\mu_1 - \mu_2`
    Variance :math:`\mu_1 + \mu_2`
    ======== ======================================

    Parameters
    ----------
    mu1 : tensor_like of float
        Mean parameter (mu1 >= 0).
    mu2 : tensor_like of float
        Mean parameter (mu2 >= 0).
    """

    @staticmethod
    def skellam_dist(mu1, mu2, size):
        if rv_size_is_none(size):
            mu1, mu2 = pt.broadcast_arrays(mu1, mu2)

        return pm.Poisson.dist(mu=mu1, size=size) - pm.Poisson.dist(mu=mu2, size=size)

    @staticmethod
    def skellam_logp(value, mu1, mu2):
        res = (
            -mu1
            - mu2
            + 0.5 * value * (pt.log(mu1) - pt.log(mu2))
            + pt.log(pt.iv(value, 2 * pt.sqrt(mu1 * mu2)))
        )
        return check_parameters(
            res,
            mu1 >= 0,
            mu2 >= 0,
            msg="mu1 >= 0, mu2 >= 0",
        )

    def __new__(cls, name, mu1, mu2, **kwargs):
        return pm.CustomDist(
            name,
            mu1,
            mu2,
            dist=cls.skellam_dist,
            logp=cls.skellam_logp,
            class_name="Skellam",
            **kwargs,
        )

    @classmethod
    def dist(cls, mu1, mu2, **kwargs):
        return pm.CustomDist.dist(
            mu1,
            mu2,
            dist=cls.skellam_dist,
            logp=cls.skellam_logp,
            class_name="Skellam",
            **kwargs,
        )
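The custom logp can be checked against SciPy's reference implementation, which the docstring plot already relies on. A small sanity check (editor's sketch, same import assumption as above):

import numpy as np
import pymc as pm
import scipy.stats as st
from pymc_extras.distributions import Skellam

# The Bessel-function logp above should agree with scipy.stats.skellam.
value, mu1, mu2 = 3, 5.0, 2.0
ours = pm.logp(Skellam.dist(mu1=mu1, mu2=mu2), value).eval()
ref = st.skellam.logpmf(value, mu1, mu2)
np.testing.assert_allclose(ours, ref, rtol=1e-6)
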
pymc_extras/distributions/histogram_utils.py
@@ -0,0 +1,163 @@
# Copyright 2022 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import pymc as pm

from numpy.typing import ArrayLike

__all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]


def quantile_histogram(
    data: ArrayLike, n_quantiles=1000, zero_inflation=False
) -> dict[str, ArrayLike]:
    try:
        import xhistogram.core
    except ImportError as e:
        raise RuntimeError("quantile_histogram requires xhistogram package") from e
    try:
        import dask.array
        import dask.dataframe
    except ImportError:
        dask = None
    if dask and isinstance(data, dask.dataframe.Series | dask.dataframe.DataFrame):
        data = data.to_dask_array(lengths=True)
    if zero_inflation:
        zeros = (data == 0).sum(0)
        mdata = np.ma.masked_where(data == 0, data)
        qdata = data[data > 0]
    else:
        mdata = data
        qdata = data.flatten()
    quantiles = np.percentile(qdata, np.linspace(0, 100, n_quantiles))
    if dask:
        (quantiles,) = dask.compute(quantiles)
    count, _ = xhistogram.core.histogram(mdata, bins=[quantiles], axis=0)
    count = count.transpose(count.ndim - 1, *range(count.ndim - 1))
    lower = quantiles[:-1]
    upper = quantiles[1:]

    if zero_inflation:
        count = np.concatenate([zeros[None], count])
        lower = np.concatenate([[0], lower])
        upper = np.concatenate([[0], upper])
    lower = lower.reshape(lower.shape + (1,) * (count.ndim - 1))
    upper = upper.reshape(upper.shape + (1,) * (count.ndim - 1))

    result = dict(
        lower=lower,
        upper=upper,
        mid=(lower + upper) / 2,
        count=count,
    )
    return result

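Shapes at a glance (editor's sketch; requires the xhistogram package, per the import guard above). n_quantiles edges yield n_quantiles - 1 bins along a new leading axis, with counts kept per column.

import numpy as np

data = np.random.normal(loc=[1.0, 3.0], scale=0.5, size=(10_000, 2))
hist = quantile_histogram(data, n_quantiles=50)
print(hist["mid"].shape, hist["count"].shape)  # (49, 1) and (49, 2)
print(hist["count"].sum(0))  # nearly all 10,000 points per column fall in a bin
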
def discrete_histogram(data: ArrayLike, min_count=None) -> dict[str, ArrayLike]:
    try:
        import xhistogram.core
    except ImportError as e:
        raise RuntimeError("discrete_histogram requires xhistogram package") from e
    try:
        import dask.array
        import dask.dataframe
    except ImportError:
        dask = None

    if dask and isinstance(data, dask.dataframe.Series | dask.dataframe.DataFrame):
        data = data.to_dask_array(lengths=True)
    mid, count_uniq = np.unique(data, return_counts=True)
    if min_count is not None:
        mid = mid[count_uniq >= min_count]
        count_uniq = count_uniq[count_uniq >= min_count]
    bins = np.concatenate([mid, [mid.max() + 1]])
    if dask:
        mid, bins = dask.compute(mid, bins)
    count, _ = xhistogram.core.histogram(data, bins=[bins], axis=0)
    count = count.transpose(count.ndim - 1, *range(count.ndim - 1))
    mid = mid.reshape(mid.shape + (1,) * (count.ndim - 1))
    return dict(mid=mid, count=count)

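And the discrete variant (same caveats): "mid" collects the unique observed values and "count" their per-column frequencies, with rare values dropped via min_count.

import numpy as np

counts = np.random.poisson(3.0, size=(5_000, 2))
hist = discrete_histogram(counts, min_count=10)
print(hist["mid"].ravel())   # unique values that met the min_count threshold
print(hist["count"].sum(0))  # frequencies per column
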
def histogram_approximation(name, dist, *, observed, **h_kwargs):
    """Approximate a distribution with a histogram potential.

    Parameters
    ----------
    name : str
        Name for the Potential
    dist : TensorVariable
        The output of pm.Distribution.dist()
    observed : ArrayLike
        Observed value to construct a histogram. The histogram is computed over the 0th axis.
        Dask is supported.

    Returns
    -------
    TensorVariable
        Potential

    Examples
    --------
    Discrete variables are reduced to unique repetitions (up to min_count)

    >>> import numpy as np
    >>> import pymc as pm
    >>> import pymc_extras as pmx
    >>> production = np.random.poisson([1, 2, 5], size=(1000, 3))
    >>> with pm.Model(coords=dict(plant=range(3))):
    ...     lam = pm.Exponential("lam", 1.0, dims="plant")
    ...     pot = pmx.distributions.histogram_approximation(
    ...         "pot", pm.Poisson.dist(lam), observed=production, min_count=2
    ...     )

    Continuous variables are discretized into n_quantiles

    >>> measurements = np.random.normal([1, 2, 3], [0.1, 0.4, 0.2], size=(10000, 3))
    >>> with pm.Model(coords=dict(tests=range(3))):
    ...     m = pm.Normal("m", dims="tests")
    ...     s = pm.LogNormal("s", dims="tests")
    ...     pot = pmx.distributions.histogram_approximation(
    ...         "pot", pm.Normal.dist(m, s),
    ...         observed=measurements, n_quantiles=50
    ...     )

    For special cases such as zero inflation in continuous variables, the
    zero_inflation flag adds a separate bin for zeros

    >>> measurements = abs(measurements)
    >>> measurements[100:] = 0
    >>> with pm.Model(coords=dict(tests=range(3))):
    ...     m = pm.Normal("m", dims="tests")
    ...     s = pm.LogNormal("s", dims="tests")
    ...     pot = pmx.distributions.histogram_approximation(
    ...         "pot", pm.Normal.dist(m, s),
    ...         observed=measurements, n_quantiles=50, zero_inflation=True
    ...     )
    """
    try:
        import dask.array
        import dask.dataframe
    except ImportError:
        dask = None
    if dask and isinstance(observed, dask.dataframe.Series | dask.dataframe.DataFrame):
        observed = observed.to_dask_array(lengths=True)
    if np.issubdtype(observed.dtype, np.integer):
        histogram = discrete_histogram(observed, **h_kwargs)
    else:
        histogram = quantile_histogram(observed, **h_kwargs)
    if dask is not None:
        (histogram,) = dask.compute(histogram)
    return pm.Potential(name, pm.logp(dist, histogram["mid"]) * histogram["count"])