pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,29 @@
1
+ # Copyright 2022 The PyMC Developers
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+
16
+ from pymc_extras import gp, statespace, utils
17
+ from pymc_extras.distributions import *
18
+ from pymc_extras.inference.fit import fit
19
+ from pymc_extras.model.marginal.marginal_model import MarginalModel, marginalize
20
+ from pymc_extras.model.model_api import as_model
21
+ from pymc_extras.version import __version__
22
+
_log = logging.getLogger("pmx")

# Attach a default handler to the "pmx" logger only when the application has
# not configured the root logger itself, so an existing logging setup wins.
if not logging.root.handlers:
    _log.setLevel(logging.INFO)
    if not _log.handlers:
        _log.addHandler(logging.StreamHandler())
@@ -0,0 +1,40 @@
1
+ # Copyright 2022 The PyMC Developers
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # coding: utf-8
16
+ """
17
+ Experimental probability distributions for stochastic nodes in PyMC.
18
+ """
19
+
20
+ from pymc_extras.distributions.continuous import Chi, GenExtreme, Maxwell
21
+ from pymc_extras.distributions.discrete import (
22
+ BetaNegativeBinomial,
23
+ GeneralizedPoisson,
24
+ Skellam,
25
+ )
26
+ from pymc_extras.distributions.histogram_utils import histogram_approximation
27
+ from pymc_extras.distributions.multivariate import R2D2M2CP
28
+ from pymc_extras.distributions.timeseries import DiscreteMarkovChain
29
+
# Public names of this subpackage; these are what a star-import of
# ``pymc_extras.distributions`` exposes.
__all__ = [
    "Chi", "Maxwell", "DiscreteMarkovChain", "GeneralizedPoisson",
    "BetaNegativeBinomial", "GenExtreme", "R2D2M2CP", "Skellam",
    "histogram_approximation",
]
@@ -0,0 +1,351 @@
1
+ # Copyright 2022 The PyMC Developers
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # coding: utf-8
16
+ """
17
+ Experimental probability distributions for stochastic nodes in PyMC.
18
+
19
+ The imports from pymc are not fully replicated here: add imports as necessary.
20
+ """
21
+
22
+ import numpy as np
23
+ import pytensor.tensor as pt
24
+
25
+ from pymc import ChiSquared, CustomDist
26
+ from pymc.distributions import transforms
27
+ from pymc.distributions.dist_math import check_parameters
28
+ from pymc.distributions.distribution import Continuous
29
+ from pymc.distributions.shape_utils import rv_size_is_none
30
+ from pymc.logprob.utils import CheckParameterValue
31
+ from pymc.pytensorf import floatX
32
+ from pytensor.tensor.random.op import RandomVariable
33
+ from pytensor.tensor.variable import TensorVariable
34
+ from scipy import stats
35
+
36
+
class GenExtremeRV(RandomVariable):
    """Random variable op that draws Generalized Extreme Value samples with SciPy."""

    name: str = "Generalized Extreme Value"
    signature = "(),(),()->()"
    dtype: str = "floatX"
    _print_name: tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}")

    def __call__(self, mu=0.0, sigma=1.0, xi=0.0, size=None, **kwargs) -> TensorVariable:
        # Forward the three scalar parameters positionally, in signature order.
        parameters = (mu, sigma, xi)
        return super().__call__(*parameters, size=size, **kwargs)

    @classmethod
    def rng_fn(
        cls,
        rng: np.random.RandomState | np.random.Generator,
        mu: np.ndarray,
        sigma: np.ndarray,
        xi: np.ndarray,
        size: tuple[int, ...],
    ) -> np.ndarray:
        # SciPy's shape ``c`` has the opposite sign of the Coles ``xi`` used by
        # ``GenExtreme``, hence the negation before sampling.
        scipy_shape = -xi
        return stats.genextreme.rvs(
            c=scipy_shape,
            loc=mu,
            scale=sigma,
            size=size,
            random_state=rng,
        )


# Shared op instance; ``GenExtreme`` uses it as its ``rv_op``.
gev = GenExtremeRV()
60
+
61
+
class GenExtreme(Continuous):
    r"""
    Univariate Generalized Extreme Value log-likelihood

    The cdf of this distribution is

    .. math::

       G(x \mid \mu, \sigma, \xi) = \exp\left[ -\left(1 + \xi z\right)^{-\frac{1}{\xi}} \right]

    where

    .. math::

        z = \frac{x - \mu}{\sigma}

    and is defined on the set:

    .. math::

        \left\{x: 1 + \xi\left(\frac{x-\mu}{\sigma}\right) > 0 \right\}.

    Note that this parametrization is per Coles (2001), and differs from that of
    Scipy in the sign of the shape parameter, :math:`\xi`.

    .. plot::

        import matplotlib.pyplot as plt
        import numpy as np
        import scipy.stats as st
        import arviz as az
        plt.style.use('arviz-darkgrid')
        x = np.linspace(-10, 20, 200)
        mus = [0., 4., -1.]
        sigmas = [2., 2., 4.]
        xis = [-0.3, 0.0, 0.3]
        for mu, sigma, xi in zip(mus, sigmas, xis):
            pdf = st.genextreme.pdf(x, c=-xi, loc=mu, scale=sigma)
            plt.plot(x, pdf, label=rf'$\mu$ = {mu}, $\sigma$ = {sigma}, $\xi$={xi}')
        plt.xlabel('x', fontsize=12)
        plt.ylabel('f(x)', fontsize=12)
        plt.legend(loc=1)
        plt.show()


    ======== =========================================================================
    Support  * :math:`x \in [\mu - \sigma/\xi, +\infty]`, when :math:`\xi > 0`
             * :math:`x \in \mathbb{R}` when :math:`\xi = 0`
             * :math:`x \in [-\infty, \mu - \sigma/\xi]`, when :math:`\xi < 0`
    Mean     * :math:`\mu + \sigma(g_1 - 1)/\xi`, when :math:`\xi \neq 0, \xi < 1`
             * :math:`\mu + \sigma \gamma`, when :math:`\xi = 0`
             * :math:`\infty`, when :math:`\xi \geq 1`
               where :math:`\gamma` is the Euler-Mascheroni constant, and
               :math:`g_k = \Gamma (1-k\xi)`
    Variance * :math:`\sigma^2 (g_2 - g_1^2)/\xi^2`, when :math:`\xi \neq 0, \xi < 0.5`
             * :math:`\frac{\pi^2}{6} \sigma^2`, when :math:`\xi = 0`
             * :math:`\infty`, when :math:`\xi \geq 0.5`
    ======== =========================================================================

    Parameters
    ----------
    mu : float
        Location parameter.
    sigma : float
        Scale parameter (sigma > 0).
    xi : float
        Shape parameter (-1 < xi < 1). Values numerically close to zero are
        evaluated with the closed-form :math:`\xi \to 0` (Gumbel) expressions.
    scipy : bool
        Whether or not to use the Scipy interpretation of the shape parameter
        (defaults to `False`).

    References
    ----------
    .. [Coles2001] Coles, S.G. (2001).
        An Introduction to the Statistical Modeling of Extreme Values
        Springer-Verlag, London

    """

    rv_op = gev

    @classmethod
    def dist(cls, mu=0, sigma=1, xi=0, scipy=False, **kwargs):
        # SciPy's ``genextreme`` shape ``c`` equals ``-xi`` in the Coles
        # parametrization used by the rest of this class, so flip it if needed.
        if scipy:
            xi = -xi
        mu = pt.as_tensor_variable(floatX(mu))
        sigma = pt.as_tensor_variable(floatX(sigma))
        xi = pt.as_tensor_variable(floatX(xi))

        return super().dist([mu, sigma, xi], **kwargs)

    def logp(value, mu, sigma, xi):
        """
        Calculate log-probability of Generalized Extreme Value distribution
        at specified value.

        Parameters
        ----------
        value: numeric
            Value(s) for which log-probability is calculated. If the log probabilities for multiple
            values are desired the values must be provided in a numpy array or Pytensor tensor
        mu, sigma, xi: tensor
            Location, scale and shape parameters.

        Returns
        -------
        TensorVariable
            The log-density, ``-inf`` wherever ``1 + xi * (value - mu) / sigma <= 0``.
        """
        scaled = (value - mu) / sigma
        is_gumbel = pt.isclose(xi, 0)
        # ``pt.switch`` builds (and differentiates through) both branches. Use a
        # non-zero stand-in shape in the branch that is discarded when xi ~ 0, so
        # its divisions by ``xi`` stay finite and cannot leak NaN into gradients.
        safe_xi = pt.switch(is_gumbel, pt.ones_like(xi), xi)

        logp_expression = pt.switch(
            is_gumbel,
            -pt.log(sigma) - scaled - pt.exp(-scaled),
            -pt.log(sigma)
            - ((safe_xi + 1) / safe_xi) * pt.log1p(safe_xi * scaled)
            - pt.pow(1 + safe_xi * scaled, -1 / safe_xi),
        )

        logp = pt.switch(pt.gt(1 + xi * scaled, 0.0), logp_expression, -np.inf)

        return check_parameters(
            logp, sigma > 0, pt.and_(xi > -1, xi < 1), msg="sigma > 0, -1 < xi < 1"
        )

    def logcdf(value, mu, sigma, xi):
        """
        Compute the log of the cumulative distribution function for Generalized Extreme Value
        distribution at the specified value.

        Parameters
        ----------
        value: numeric or np.ndarray or `TensorVariable`
            Value(s) for which log CDF is calculated. If the log CDF for
            multiple values are desired the values must be provided in a numpy
            array or `TensorVariable`.
        mu, sigma, xi: tensor
            Location, scale and shape parameters.

        Returns
        -------
        TensorVariable
            The log-CDF. Outside the support it is ``-inf`` below the lower
            endpoint (``xi > 0``) and ``0`` above the upper endpoint (``xi < 0``).
        """
        scaled = (value - mu) / sigma
        is_gumbel = pt.isclose(xi, 0)
        # Same stand-in as in ``logp``: keeps the discarded branch finite at xi == 0.
        safe_xi = pt.switch(is_gumbel, pt.ones_like(xi), xi)
        logc_expression = pt.switch(
            is_gumbel, -pt.exp(-scaled), -pt.pow(1 + safe_xi * scaled, -1 / safe_xi)
        )

        # ``1 + xi * z <= 0`` requires xi != 0. For xi > 0 the value is below the
        # lower endpoint (CDF = 0, log = -inf); for xi < 0 it is above the upper
        # endpoint (CDF = 1, log = 0).
        outside_support = pt.switch(xi < 0, 0.0, -np.inf)
        logc = pt.switch(1 + xi * scaled > 0, logc_expression, outside_support)

        return check_parameters(
            logc, sigma > 0, pt.and_(xi > -1, xi < 1), msg="sigma > 0, -1 < xi < 1"
        )

    def support_point(rv, size, mu, sigma, xi):
        r"""
        Return the mode as the support point.

        The mode has a simple closed form for every valid :math:`\xi`, whereas the
        mean involves Gamma functions and is infinite when :math:`\xi \geq 1`.
        """
        is_gumbel = pt.isclose(xi, 0)
        # Avoid the 0/0 in the discarded branch when xi == 0.
        safe_xi = pt.switch(is_gumbel, pt.ones_like(xi), xi)
        mode = pt.switch(
            is_gumbel, mu, mu + sigma * (pt.pow(1 + safe_xi, -safe_xi) - 1) / safe_xi
        )
        if not rv_size_is_none(size):
            mode = pt.full(size, mode)
        return mode
220
+
221
+
class Chi:
    r"""
    :math:`\chi` log-likelihood.

    The pdf of this distribution is

    .. math::

       f(x \mid \nu) = \frac{x^{\nu - 1}e^{-x^2/2}}{2^{\nu/2 - 1}\Gamma(\nu/2)}

    .. plot::
        :context: close-figs

        import matplotlib.pyplot as plt
        import numpy as np
        import scipy.stats as st
        import arviz as az
        plt.style.use('arviz-darkgrid')
        x = np.linspace(0, 10, 200)
        for df in [1, 2, 3, 6, 9]:
            pdf = st.chi.pdf(x, df)
            plt.plot(x, pdf, label=r'$\nu$ = {}'.format(df))
        plt.xlabel('x', fontsize=12)
        plt.ylabel('f(x)', fontsize=12)
        plt.legend(loc=1)
        plt.show()

    ======== =========================================================================
    Support  :math:`x \in [0, \infty)`
    Mean     :math:`\sqrt{2}\frac{\Gamma((\nu + 1)/2)}{\Gamma(\nu/2)}`
    Variance :math:`\nu - 2\left(\frac{\Gamma((\nu + 1)/2)}{\Gamma(\nu/2)}\right)^2`
    ======== =========================================================================

    Parameters
    ----------
    nu : tensor_like of float
        Degrees of freedom (nu > 0).

    Notes
    -----
    Built as the square root of a ``ChiSquared`` variable with the same degrees
    of freedom. Unobserved variables created through the class constructor use
    a log transform by default.

    Examples
    --------
    .. code-block:: python

        import pymc as pm
        from pymc_extras.distributions import Chi

        with pm.Model():
            x = Chi('x', nu=1)
    """

    @staticmethod
    def chi_dist(nu: TensorVariable, size: TensorVariable) -> TensorVariable:
        # A chi variate is the square root of a chi-squared variate.
        chi_squared = ChiSquared.dist(nu=nu, size=size)
        return pt.sqrt(chi_squared)

    def __new__(cls, name, nu, **kwargs):
        if "observed" not in kwargs:
            # Unless the caller chose one, default to a log transform.
            kwargs = {"default_transform": transforms.log, **kwargs}
        return CustomDist(name, nu, dist=cls.chi_dist, class_name="Chi", **kwargs)

    @classmethod
    def dist(cls, nu, **kwargs):
        return CustomDist.dist(nu, class_name="Chi", dist=cls.chi_dist, **kwargs)
282
+
283
+
class Maxwell:
    r"""
    The Maxwell-Boltzmann distribution

    The pdf of this distribution is

    .. math::

       f(x \mid a) = {\displaystyle {\sqrt {\frac {2}{\pi }}}\,{\frac {x^{2}}{a^{3}}}\,\exp \left({\frac {-x^{2}}{2a^{2}}}\right)}

    Read more about it on `Wikipedia <https://en.wikipedia.org/wiki/Maxwell%E2%80%93Boltzmann_distribution>`_

    .. plot::
        :context: close-figs

        import matplotlib.pyplot as plt
        import numpy as np
        import scipy.stats as st
        import arviz as az
        plt.style.use('arviz-darkgrid')
        x = np.linspace(0, 20, 200)
        for a in [1, 2, 5]:
            pdf = st.maxwell.pdf(x, scale=a)
            plt.plot(x, pdf, label=r'$a$ = {}'.format(a))
        plt.xlabel('x', fontsize=12)
        plt.ylabel('f(x)', fontsize=12)
        plt.legend(loc=1)
        plt.show()

    ======== =========================================================================
    Support  :math:`x \in (0, \infty)`
    Mean     :math:`2a \sqrt{\frac{2}{\pi}}`
    Variance :math:`\frac{a^2(3 \pi - 8)}{\pi}`
    ======== =========================================================================

    Parameters
    ----------
    a : tensor_like of float
        Scale parameter (a > 0).

    Notes
    -----
    Unobserved variables created through the class constructor use a log
    transform by default (pass ``default_transform`` explicitly to override).

    """

    @staticmethod
    def maxwell_dist(a: TensorVariable, size: TensorVariable) -> TensorVariable:
        # Without an explicit size, draw one value per element of ``a``.
        if rv_size_is_none(size):
            size = a.shape

        # Raise a parameter error for non-positive scales.
        a = CheckParameterValue("a > 0")(a, pt.all(pt.gt(a, 0)))

        # A Maxwell(a) variate is a chi variate with 3 degrees of freedom scaled by ``a``.
        return Chi.dist(nu=3, size=size) * a

    def __new__(cls, name, a, **kwargs):
        # The support is positive; as with ``Chi``, default unobserved variables
        # to a log transform so samplers operate on an unconstrained space.
        if "observed" not in kwargs:
            kwargs.setdefault("default_transform", transforms.log)
        return CustomDist(
            name,
            a,
            dist=cls.maxwell_dist,
            class_name="Maxwell",
            **kwargs,
        )

    @classmethod
    def dist(cls, a, **kwargs):
        return CustomDist.dist(
            a,
            dist=cls.maxwell_dist,
            class_name="Maxwell",
            **kwargs,
        )