pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
pymc_extras/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright 2022 The PyMC Developers
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import logging
|
|
15
|
+
|
|
16
|
+
from pymc_extras import gp, statespace, utils
|
|
17
|
+
from pymc_extras.distributions import *
|
|
18
|
+
from pymc_extras.inference.fit import fit
|
|
19
|
+
from pymc_extras.model.marginal.marginal_model import MarginalModel, marginalize
|
|
20
|
+
from pymc_extras.model.model_api import as_model
|
|
21
|
+
from pymc_extras.version import __version__
|
|
22
|
+
|
|
23
|
+
# Package-level logger; named "pmx" to match the project's short name.
_log = logging.getLogger("pmx")

# Only configure logging when the host application has not done so itself:
# if the root logger already has handlers, respect that configuration and
# stay silent instead of double-emitting records.
if not logging.root.handlers:
    _log.setLevel(logging.INFO)
    # Guard against attaching a second handler on repeated imports/reloads.
    if not _log.handlers:
        handler = logging.StreamHandler()
        _log.addHandler(handler)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Copyright 2022 The PyMC Developers
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# coding: utf-8
|
|
16
|
+
"""
|
|
17
|
+
Experimental probability distributions for stochastic nodes in PyMC.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from pymc_extras.distributions.continuous import Chi, GenExtreme, Maxwell
|
|
21
|
+
from pymc_extras.distributions.discrete import (
|
|
22
|
+
BetaNegativeBinomial,
|
|
23
|
+
GeneralizedPoisson,
|
|
24
|
+
Skellam,
|
|
25
|
+
)
|
|
26
|
+
from pymc_extras.distributions.histogram_utils import histogram_approximation
|
|
27
|
+
from pymc_extras.distributions.multivariate import R2D2M2CP
|
|
28
|
+
from pymc_extras.distributions.timeseries import DiscreteMarkovChain
|
|
29
|
+
|
|
30
|
+
# Public API of pymc_extras.distributions, re-exported by
# `from pymc_extras.distributions import *` (and thus by the top-level
# `pymc_extras` package, which star-imports this module).
__all__ = [
    "Chi",
    "Maxwell",
    "DiscreteMarkovChain",
    "GeneralizedPoisson",
    "BetaNegativeBinomial",
    "GenExtreme",
    "R2D2M2CP",
    "Skellam",
    "histogram_approximation",
]
|
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
# Copyright 2022 The PyMC Developers
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# coding: utf-8
|
|
16
|
+
"""
|
|
17
|
+
Experimental probability distributions for stochastic nodes in PyMC.
|
|
18
|
+
|
|
19
|
+
The imports from pymc are not fully replicated here: add imports as necessary.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
import pytensor.tensor as pt
|
|
24
|
+
|
|
25
|
+
from pymc import ChiSquared, CustomDist
|
|
26
|
+
from pymc.distributions import transforms
|
|
27
|
+
from pymc.distributions.dist_math import check_parameters
|
|
28
|
+
from pymc.distributions.distribution import Continuous
|
|
29
|
+
from pymc.distributions.shape_utils import rv_size_is_none
|
|
30
|
+
from pymc.logprob.utils import CheckParameterValue
|
|
31
|
+
from pymc.pytensorf import floatX
|
|
32
|
+
from pytensor.tensor.random.op import RandomVariable
|
|
33
|
+
from pytensor.tensor.variable import TensorVariable
|
|
34
|
+
from scipy import stats
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class GenExtremeRV(RandomVariable):
    """Random-variable Op that draws from the Generalized Extreme Value distribution.

    The shape parameter ``xi`` follows the Coles (2001) convention; scipy's
    shape parameter ``c`` is its negative, so the sign is flipped in ``rng_fn``.
    """

    name: str = "Generalized Extreme Value"
    signature = "(),(),()->()"
    dtype: str = "floatX"
    _print_name: tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}")

    def __call__(self, mu=0.0, sigma=1.0, xi=0.0, size=None, **kwargs) -> TensorVariable:
        """Build the symbolic draw node with defaulted location/scale/shape."""
        return super().__call__(mu, sigma, xi, size=size, **kwargs)

    @classmethod
    def rng_fn(
        cls,
        rng: np.random.RandomState | np.random.Generator,
        mu: np.ndarray,
        sigma: np.ndarray,
        xi: np.ndarray,
        size: tuple[int, ...],
    ) -> np.ndarray:
        """Draw concrete samples via scipy, converting the shape-parameter sign."""
        # scipy parametrizes with `c = -xi` relative to the Coles convention
        # used throughout GenExtreme.
        return stats.genextreme.rvs(c=-xi, loc=mu, scale=sigma, size=size, random_state=rng)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Module-level Op instance; wired into the distribution as `GenExtreme.rv_op`.
gev = GenExtremeRV()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class GenExtreme(Continuous):
    r"""
    Univariate Generalized Extreme Value log-likelihood

    The cdf of this distribution is

    .. math::

       G(x \mid \mu, \sigma, \xi) = \exp\left[ -\left(1 + \xi z\right)^{-\frac{1}{\xi}} \right]

    where

    .. math::

        z = \frac{x - \mu}{\sigma}

    and is defined on the set:

    .. math::

        \left\{x: 1 + \xi\left(\frac{x-\mu}{\sigma}\right) > 0 \right\}.

    Note that this parametrization is per Coles (2001), and differs from that of
    Scipy in the sign of the shape parameter, :math:`\xi`.

    .. plot::

        import matplotlib.pyplot as plt
        import numpy as np
        import scipy.stats as st
        import arviz as az
        plt.style.use('arviz-darkgrid')
        x = np.linspace(-10, 20, 200)
        mus = [0., 4., -1.]
        sigmas = [2., 2., 4.]
        xis = [-0.3, 0.0, 0.3]
        for mu, sigma, xi in zip(mus, sigmas, xis):
            pdf = st.genextreme.pdf(x, c=-xi, loc=mu, scale=sigma)
            plt.plot(x, pdf, label=rf'$\mu$ = {mu}, $\sigma$ = {sigma}, $\xi$={xi}')
        plt.xlabel('x', fontsize=12)
        plt.ylabel('f(x)', fontsize=12)
        plt.legend(loc=1)
        plt.show()


    ======== =========================================================================
    Support  * :math:`x \in [\mu - \sigma/\xi, +\infty]`, when :math:`\xi > 0`
             * :math:`x \in \mathbb{R}` when :math:`\xi = 0`
             * :math:`x \in [-\infty, \mu - \sigma/\xi]`, when :math:`\xi < 0`
    Mean     * :math:`\mu + \sigma(g_1 - 1)/\xi`, when :math:`\xi \neq 0, \xi < 1`
             * :math:`\mu + \sigma \gamma`, when :math:`\xi = 0`
             * :math:`\infty`, when :math:`\xi \geq 1`
               where :math:`\gamma` is the Euler-Mascheroni constant, and
               :math:`g_k = \Gamma (1-k\xi)`
    Variance * :math:`\sigma^2 (g_2 - g_1^2)/\xi^2`, when :math:`\xi \neq 0, \xi < 0.5`
             * :math:`\frac{\pi^2}{6} \sigma^2`, when :math:`\xi = 0`
             * :math:`\infty`, when :math:`\xi \geq 0.5`
    ======== =========================================================================

    Parameters
    ----------
    mu : float
        Location parameter.
    sigma : float
        Scale parameter (sigma > 0).
    xi : float
        Shape parameter
    scipy : bool
        Whether or not to use the Scipy interpretation of the shape parameter
        (defaults to `False`).

    References
    ----------
    .. [Coles2001] Coles, S.G. (2001).
        An Introduction to the Statistical Modeling of Extreme Values
        Springer-Verlag, London

    """

    rv_op = gev

    @classmethod
    def dist(cls, mu=0, sigma=1, xi=0, scipy=False, **kwargs):
        # If SciPy, use its parametrization, otherwise convert to standard.
        # scipy's shape parameter is the negative of the Coles one.
        if scipy:
            xi = -xi
        mu = pt.as_tensor_variable(floatX(mu))
        sigma = pt.as_tensor_variable(floatX(sigma))
        xi = pt.as_tensor_variable(floatX(xi))

        return super().dist([mu, sigma, xi], **kwargs)

    def logp(value, mu, sigma, xi):
        """
        Calculate log-probability of Generalized Extreme Value distribution
        at specified value.

        Parameters
        ----------
        value: numeric
            Value(s) for which log-probability is calculated. If the log probabilities for multiple
            values are desired the values must be provided in a numpy array or Pytensor tensor

        Returns
        -------
        TensorVariable
        """
        scaled = (value - mu) / sigma

        # Gumbel limit (xi -> 0) taken explicitly to avoid 0/0 in the general form.
        logp_expression = pt.switch(
            pt.isclose(xi, 0),
            -pt.log(sigma) - scaled - pt.exp(-scaled),
            -pt.log(sigma)
            - ((xi + 1) / xi) * pt.log1p(xi * scaled)
            - pt.pow(1 + xi * scaled, -1 / xi),
        )

        # Outside the support (1 + xi*z <= 0) the density is zero.
        logp = pt.switch(pt.gt(1 + xi * scaled, 0.0), logp_expression, -np.inf)

        # Both constraints must hold jointly for valid parameters.
        return check_parameters(
            logp, sigma > 0, pt.and_(xi > -1, xi < 1), msg="sigma > 0 and -1 < xi < 1"
        )

    def logcdf(value, mu, sigma, xi):
        """
        Compute the log of the cumulative distribution function for Generalized Extreme Value
        distribution at the specified value.

        Parameters
        ----------
        value: numeric or np.ndarray or `TensorVariable`
            Value(s) for which log CDF is calculated. If the log CDF for
            multiple values are desired the values must be provided in a numpy
            array or `TensorVariable`.

        Returns
        -------
        TensorVariable
        """
        scaled = (value - mu) / sigma
        # Gumbel limit (xi -> 0) taken explicitly, mirroring logp.
        logc_expression = pt.switch(
            pt.isclose(xi, 0), -pt.exp(-scaled), -pt.pow(1 + xi * scaled, -1 / xi)
        )

        # Reuse `scaled` for the support condition (consistent with logp).
        logc = pt.switch(pt.gt(1 + xi * scaled, 0.0), logc_expression, -np.inf)

        return check_parameters(
            logc, sigma > 0, pt.and_(xi > -1, xi < 1), msg="sigma > 0 and -1 < xi < 1"
        )

    def support_point(rv, size, mu, sigma, xi):
        r"""
        Using the mode, as the mean can be infinite when :math:`\xi > 1`
        """
        # Mode of the GEV; the xi -> 0 (Gumbel) limit is simply mu.
        mode = pt.switch(pt.isclose(xi, 0), mu, mu + sigma * (pt.pow(1 + xi, -xi) - 1) / xi)
        if not rv_size_is_none(size):
            mode = pt.full(size, mode)
        return mode
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
class Chi:
    r"""
    :math:`\chi` log-likelihood.

    The distribution of the square root of a :math:`\chi^2` variate; its pdf is

    .. math::

       f(x \mid \nu) = \frac{x^{\nu - 1}e^{-x^2/2}}{2^{\nu/2 - 1}\Gamma(\nu/2)}

    ======== =========================================================================
    Support  :math:`x \in [0, \infty)`
    Mean     :math:`\sqrt{2}\frac{\Gamma((\nu + 1)/2)}{\Gamma(\nu/2)}`
    Variance :math:`\nu - 2\left(\frac{\Gamma((\nu + 1)/2)}{\Gamma(\nu/2)}\right)^2`
    ======== =========================================================================

    Parameters
    ----------
    nu : tensor_like of float
        Degrees of freedom (nu > 0).

    Examples
    --------
    .. code-block:: python

        import pymc as pm
        from pymc_extras.distributions import Chi

        with pm.Model():
            x = Chi('x', nu=1)
    """

    @staticmethod
    def chi_dist(nu: TensorVariable, size: TensorVariable) -> TensorVariable:
        # A Chi variate is the square root of a ChiSquared variate with the
        # same degrees of freedom.
        chi_squared = ChiSquared.dist(nu=nu, size=size)
        return pt.math.sqrt(chi_squared)

    def __new__(cls, name, nu, **kwargs):
        # Latent (non-observed) variables default to a log transform so the
        # sampler works on an unconstrained space; an explicit user-supplied
        # transform takes precedence.
        if "observed" not in kwargs and "default_transform" not in kwargs:
            kwargs["default_transform"] = transforms.log
        return CustomDist(name, nu, dist=cls.chi_dist, class_name="Chi", **kwargs)

    @classmethod
    def dist(cls, nu, **kwargs):
        """Create an unnamed Chi RV, usable outside a model context."""
        return CustomDist.dist(nu, dist=cls.chi_dist, class_name="Chi", **kwargs)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
class Maxwell:
    R"""
    The Maxwell-Boltzmann distribution

    The pdf of this distribution is

    .. math::

       f(x \mid a) = {\displaystyle {\sqrt {\frac {2}{\pi }}}\,{\frac {x^{2}}{a^{3}}}\,\exp \left({\frac {-x^{2}}{2a^{2}}}\right)}

    Read more about it on `Wikipedia <https://en.wikipedia.org/wiki/Maxwell%E2%80%93Boltzmann_distribution>`_

    ======== =========================================================================
    Support  :math:`x \in (0, \infty)`
    Mean     :math:`2a \sqrt{\frac{2}{\pi}}`
    Variance :math:`\frac{a^2(3 \pi - 8)}{\pi}`
    ======== =========================================================================

    Parameters
    ----------
    a : tensor_like of float
        Scale parameter (a > 0).

    """

    @staticmethod
    def maxwell_dist(a: TensorVariable, size: TensorVariable) -> TensorVariable:
        # Default the draw shape to the shape of the scale parameter.
        if rv_size_is_none(size):
            size = a.shape

        # Fail loudly at evaluation time when the scale is non-positive.
        a = CheckParameterValue("a > 0")(a, pt.all(pt.gt(a, 0)))

        # A Maxwell-Boltzmann variate is `a` times a Chi variate with
        # three degrees of freedom.
        return a * Chi.dist(nu=3, size=size)

    def __new__(cls, name, a, **kwargs):
        return CustomDist(name, a, dist=cls.maxwell_dist, class_name="Maxwell", **kwargs)

    @classmethod
    def dist(cls, a, **kwargs):
        """Create an unnamed Maxwell RV, usable outside a model context."""
        return CustomDist.dist(a, dist=cls.maxwell_dist, class_name="Maxwell", **kwargs)
|