pymc-extras: pymc_extras-0.2.5-py3-none-any.whl → pymc_extras-0.2.6-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. pymc_extras/__init__.py +5 -1
  2. pymc_extras/distributions/continuous.py +3 -2
  3. pymc_extras/distributions/discrete.py +3 -1
  4. pymc_extras/inference/find_map.py +62 -17
  5. pymc_extras/inference/laplace.py +10 -7
  6. pymc_extras/statespace/core/statespace.py +191 -52
  7. pymc_extras/statespace/filters/distributions.py +15 -16
  8. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  9. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  10. pymc_extras/statespace/models/ETS.py +10 -0
  11. pymc_extras/statespace/models/SARIMAX.py +26 -5
  12. pymc_extras/statespace/models/VARMAX.py +12 -2
  13. pymc_extras/statespace/models/structural.py +18 -5
  14. pymc_extras-0.2.6.dist-info/METADATA +318 -0
  15. pymc_extras-0.2.6.dist-info/RECORD +65 -0
  16. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
  17. pymc_extras/version.py +0 -11
  18. pymc_extras/version.txt +0 -1
  19. pymc_extras-0.2.5.dist-info/METADATA +0 -112
  20. pymc_extras-0.2.5.dist-info/RECORD +0 -108
  21. pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
  22. tests/__init__.py +0 -13
  23. tests/distributions/__init__.py +0 -19
  24. tests/distributions/test_continuous.py +0 -185
  25. tests/distributions/test_discrete.py +0 -210
  26. tests/distributions/test_discrete_markov_chain.py +0 -258
  27. tests/distributions/test_multivariate.py +0 -304
  28. tests/distributions/test_transform.py +0 -77
  29. tests/model/__init__.py +0 -0
  30. tests/model/marginal/__init__.py +0 -0
  31. tests/model/marginal/test_distributions.py +0 -132
  32. tests/model/marginal/test_graph_analysis.py +0 -182
  33. tests/model/marginal/test_marginal_model.py +0 -967
  34. tests/model/test_model_api.py +0 -38
  35. tests/statespace/__init__.py +0 -0
  36. tests/statespace/test_ETS.py +0 -411
  37. tests/statespace/test_SARIMAX.py +0 -405
  38. tests/statespace/test_VARMAX.py +0 -184
  39. tests/statespace/test_coord_assignment.py +0 -181
  40. tests/statespace/test_distributions.py +0 -270
  41. tests/statespace/test_kalman_filter.py +0 -326
  42. tests/statespace/test_representation.py +0 -175
  43. tests/statespace/test_statespace.py +0 -872
  44. tests/statespace/test_statespace_JAX.py +0 -156
  45. tests/statespace/test_structural.py +0 -836
  46. tests/statespace/utilities/__init__.py +0 -0
  47. tests/statespace/utilities/shared_fixtures.py +0 -9
  48. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  49. tests/statespace/utilities/test_helpers.py +0 -310
  50. tests/test_blackjax_smc.py +0 -222
  51. tests/test_find_map.py +0 -103
  52. tests/test_histogram_approximation.py +0 -109
  53. tests/test_laplace.py +0 -281
  54. tests/test_linearmodel.py +0 -208
  55. tests/test_model_builder.py +0 -306
  56. tests/test_pathfinder.py +0 -297
  57. tests/test_pivoted_cholesky.py +0 -24
  58. tests/test_printing.py +0 -98
  59. tests/test_prior_from_trace.py +0 -172
  60. tests/test_splines.py +0 -77
  61. tests/utils.py +0 -0
  62. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.6.dist-info}/licenses/LICENSE +0 -0
tests/__init__.py DELETED
@@ -1,13 +0,0 @@
1
- # Copyright 2020 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
tests/distributions/__init__.py DELETED
@@ -1,19 +0,0 @@
1
- # Copyright 2022 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from pymc_extras.distributions import histogram_utils
17
- from pymc_extras.distributions.histogram_utils import histogram_approximation
18
-
19
- __all__ = ["histogram_utils", "histogram_approximation"]
tests/distributions/test_continuous.py DELETED
@@ -1,185 +0,0 @@
1
- # Copyright 2020 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import numpy as np
15
- import pymc as pm
16
-
17
- # general imports
18
- import pytest
19
- import scipy.stats.distributions as sp
20
-
21
-
22
- # test support imports from pymc
23
- from pymc.testing import (
24
- BaseTestDistributionRandom,
25
- Domain,
26
- R,
27
- Rplus,
28
- Rplusbig,
29
- assert_support_point_is_expected,
30
- check_logcdf,
31
- check_logp,
32
- seeded_scipy_distribution_builder,
33
- select_by_precision,
34
- )
35
-
36
- # the distributions to be tested
37
- from pymc_extras.distributions import Chi, GenExtreme, Maxwell
38
-
39
-
40
- class TestGenExtremeClass:
41
- """
42
- Wrapper class so that tests of experimental additions can be dropped into
43
- PyMC directly on adoption.
44
-
45
- pm.logp(GenExtreme.dist(mu=0.,sigma=1.,xi=0.5),value=-0.01)
46
- """
47
-
48
- def test_logp(self):
49
- def ref_logp(value, mu, sigma, xi):
50
- if 1 + xi * (value - mu) / sigma > 0:
51
- return sp.genextreme.logpdf(value, c=-xi, loc=mu, scale=sigma)
52
- else:
53
- return -np.inf
54
-
55
- check_logp(
56
- GenExtreme,
57
- R,
58
- {
59
- "mu": R,
60
- "sigma": Rplusbig,
61
- "xi": Domain([-1, -0.99, -0.5, 0, 0.5, 0.99, 1]),
62
- },
63
- ref_logp,
64
- )
65
-
66
- def test_logcdf(self):
67
- def ref_logcdf(value, mu, sigma, xi):
68
- if 1 + xi * (value - mu) / sigma > 0:
69
- return sp.genextreme.logcdf(value, c=-xi, loc=mu, scale=sigma)
70
- else:
71
- return -np.inf
72
-
73
- check_logcdf(
74
- GenExtreme,
75
- R,
76
- {
77
- "mu": R,
78
- "sigma": Rplusbig,
79
- "xi": Domain([-1, -0.99, -0.5, 0, 0.5, 0.99, 1]),
80
- },
81
- ref_logcdf,
82
- decimal=select_by_precision(float64=6, float32=2),
83
- )
84
-
85
- @pytest.mark.parametrize(
86
- "mu, sigma, xi, size, expected",
87
- [
88
- (0, 1, 0, None, 0),
89
- (1, np.arange(1, 4), 0.1, None, 1 + np.arange(1, 4) * (1.1**-0.1 - 1) / 0.1),
90
- (np.arange(5), 1, 0.1, None, np.arange(5) + (1.1**-0.1 - 1) / 0.1),
91
- (
92
- 0,
93
- 1,
94
- np.linspace(-0.2, 0.2, 6),
95
- None,
96
- ((1 + np.linspace(-0.2, 0.2, 6)) ** -np.linspace(-0.2, 0.2, 6) - 1)
97
- / np.linspace(-0.2, 0.2, 6),
98
- ),
99
- (1, 2, 0.1, 5, np.full(5, 1 + 2 * (1.1**-0.1 - 1) / 0.1)),
100
- (
101
- np.arange(6),
102
- np.arange(1, 7),
103
- np.linspace(-0.2, 0.2, 6),
104
- (3, 6),
105
- np.full(
106
- (3, 6),
107
- np.arange(6)
108
- + np.arange(1, 7)
109
- * ((1 + np.linspace(-0.2, 0.2, 6)) ** -np.linspace(-0.2, 0.2, 6) - 1)
110
- / np.linspace(-0.2, 0.2, 6),
111
- ),
112
- ),
113
- ],
114
- )
115
- def test_genextreme_support_point(self, mu, sigma, xi, size, expected):
116
- with pm.Model() as model:
117
- GenExtreme("x", mu=mu, sigma=sigma, xi=xi, size=size)
118
- assert_support_point_is_expected(model, expected)
119
-
120
- def test_gen_extreme_scipy_kwarg(self):
121
- dist = GenExtreme.dist(xi=1, scipy=False)
122
- assert dist.owner.inputs[-1].eval() == 1
123
-
124
- dist = GenExtreme.dist(xi=1, scipy=True)
125
- assert dist.owner.inputs[-1].eval() == -1
126
-
127
-
128
- class TestGenExtreme(BaseTestDistributionRandom):
129
- pymc_dist = GenExtreme
130
- pymc_dist_params = {"mu": 0, "sigma": 1, "xi": -0.1}
131
- expected_rv_op_params = {"mu": 0, "sigma": 1, "xi": -0.1}
132
- # Notice, using different parametrization of xi sign to scipy
133
- reference_dist_params = {"loc": 0, "scale": 1, "c": 0.1}
134
- reference_dist = seeded_scipy_distribution_builder("genextreme")
135
- tests_to_run = [
136
- "check_pymc_params_match_rv_op",
137
- "check_pymc_draws_match_reference",
138
- "check_rv_size",
139
- ]
140
-
141
-
142
- class TestChiClass:
143
- """
144
- Wrapper class so that tests of experimental additions can be dropped into
145
- PyMC directly on adoption.
146
- """
147
-
148
- def test_logp(self):
149
- check_logp(
150
- Chi,
151
- Rplus,
152
- {"nu": Rplus},
153
- lambda value, nu: sp.chi.logpdf(value, df=nu),
154
- )
155
-
156
- def test_logcdf(self):
157
- check_logcdf(
158
- Chi,
159
- Rplus,
160
- {"nu": Rplus},
161
- lambda value, nu: sp.chi.logcdf(value, df=nu),
162
- )
163
-
164
-
165
- class TestMaxwell:
166
- """
167
- Wrapper class so that tests of experimental additions can be dropped into
168
- PyMC directly on adoption.
169
- """
170
-
171
- def test_logp(self):
172
- check_logp(
173
- Maxwell,
174
- Rplus,
175
- {"a": Rplus},
176
- lambda value, a: sp.maxwell.logpdf(value, scale=a),
177
- )
178
-
179
- def test_logcdf(self):
180
- check_logcdf(
181
- Maxwell,
182
- Rplus,
183
- {"a": Rplus},
184
- lambda value, a: sp.maxwell.logcdf(value, scale=a),
185
- )
tests/distributions/test_discrete.py DELETED
@@ -1,210 +0,0 @@
1
- # Copyright 2023 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import numpy as np
15
- import pymc as pm
16
- import pytensor
17
- import pytensor.tensor as pt
18
- import pytest
19
- import scipy.stats
20
-
21
- from pymc.logprob.utils import ParameterValueError
22
- from pymc.testing import (
23
- BaseTestDistributionRandom,
24
- Domain,
25
- I,
26
- Rplus,
27
- assert_support_point_is_expected,
28
- check_logp,
29
- discrete_random_tester,
30
- )
31
- from pytensor import config
32
-
33
- from pymc_extras.distributions import (
34
- BetaNegativeBinomial,
35
- GeneralizedPoisson,
36
- Skellam,
37
- )
38
-
39
-
40
- class TestGeneralizedPoisson:
41
- class TestRandomVariable(BaseTestDistributionRandom):
42
- pymc_dist = GeneralizedPoisson
43
- pymc_dist_params = {"mu": 4.0, "lam": 1.0}
44
- expected_rv_op_params = {"mu": 4.0, "lam": 1.0}
45
- tests_to_run = [
46
- "check_pymc_params_match_rv_op",
47
- "check_rv_size",
48
- ]
49
-
50
- def test_random_matches_poisson(self):
51
- discrete_random_tester(
52
- dist=self.pymc_dist,
53
- paramdomains={"mu": Rplus, "lam": Domain([0], edges=(None, None))},
54
- ref_rand=lambda mu, lam, size: scipy.stats.poisson.rvs(mu, size=size),
55
- )
56
-
57
- @pytest.mark.parametrize("mu", (2.5, 20, 50))
58
- def test_random_lam_expected_moments(self, mu):
59
- lam = np.array([-0.9, -0.7, -0.2, 0, 0.2, 0.7, 0.9])
60
- dist = self.pymc_dist.dist(mu=mu, lam=lam, size=(10_000, len(lam)))
61
- draws = dist.eval()
62
-
63
- expected_mean = mu / (1 - lam)
64
- np.testing.assert_allclose(draws.mean(0), expected_mean, rtol=1e-1)
65
-
66
- expected_std = np.sqrt(mu / (1 - lam) ** 3)
67
- np.testing.assert_allclose(draws.std(0), expected_std, rtol=1e-1)
68
-
69
- def test_logp_matches_poisson(self):
70
- # We are only checking this distribution for lambda=0 where it's equivalent to Poisson.
71
- mu = pt.scalar("mu")
72
- lam = pt.scalar("lam")
73
- value = pt.vector("value", dtype="int64")
74
-
75
- logp = pm.logp(GeneralizedPoisson.dist(mu, lam), value)
76
- logp_fn = pytensor.function([value, mu, lam], logp)
77
-
78
- test_value = np.array([0, 1, 2, 30])
79
- for test_mu in (0.01, 0.1, 0.9, 1, 1.5, 20, 100):
80
- np.testing.assert_allclose(
81
- logp_fn(test_value, test_mu, lam=0),
82
- scipy.stats.poisson.logpmf(test_value, test_mu),
83
- rtol=1e-7 if config.floatX == "float64" else 1e-5,
84
- )
85
-
86
- # Check out-of-bounds values
87
- value = pt.scalar("value")
88
- logp = pm.logp(GeneralizedPoisson.dist(mu, lam), value)
89
- logp_fn = pytensor.function([value, mu, lam], logp)
90
-
91
- logp_fn(-1, mu=5, lam=0) == -np.inf
92
- logp_fn(9, mu=5, lam=-1) == -np.inf
93
-
94
- # Check mu/lam restrictions
95
- with pytest.raises(ParameterValueError):
96
- logp_fn(1, mu=1, lam=2)
97
-
98
- with pytest.raises(ParameterValueError):
99
- logp_fn(1, mu=0, lam=0)
100
-
101
- with pytest.raises(ParameterValueError):
102
- logp_fn(1, mu=1, lam=-1)
103
-
104
- def test_logp_lam_expected_moments(self):
105
- mu = 30
106
- lam = np.array([-0.9, -0.7, -0.2, 0, 0.2, 0.7, 0.9])
107
- with pm.Model():
108
- x = GeneralizedPoisson("x", mu=mu, lam=lam)
109
- trace = pm.sample(chains=1, draws=10_000, random_seed=96).posterior
110
-
111
- expected_mean = mu / (1 - lam)
112
- np.testing.assert_allclose(trace["x"].mean(("chain", "draw")), expected_mean, rtol=1e-1)
113
-
114
- expected_std = np.sqrt(mu / (1 - lam) ** 3)
115
- np.testing.assert_allclose(trace["x"].std(("chain", "draw")), expected_std, rtol=1e-1)
116
-
117
- @pytest.mark.parametrize(
118
- "mu, lam, size, expected",
119
- [
120
- (50, [-0.6, 0, 0.6], None, np.floor(50 / (1 - np.array([-0.6, 0, 0.6])))),
121
- ([5, 50], -0.1, (4, 2), np.full((4, 2), np.floor(np.array([5, 50]) / 1.1))),
122
- ],
123
- )
124
- def test_moment(self, mu, lam, size, expected):
125
- with pm.Model() as model:
126
- GeneralizedPoisson("x", mu=mu, lam=lam, size=size)
127
- assert_support_point_is_expected(model, expected)
128
-
129
-
130
- class TestBetaNegativeBinomial:
131
- """
132
- Wrapper class so that tests of experimental additions can be dropped into
133
- PyMC directly on adoption.
134
- """
135
-
136
- def test_logp(self):
137
- """
138
-
139
- Beta Negative Binomial logp function test values taken from R package as
140
- there is currently no implementation in scipy.
141
- https://github.com/scipy/scipy/issues/17330
142
-
143
- The test values can be generated in R with the following code:
144
-
145
- .. code-block:: r
146
-
147
- library(extraDistr)
148
-
149
- create.test.rows <- function(alpha, beta, r, x) {
150
- logp <- dbnbinom(x, alpha, beta, r, log=TRUE)
151
- paste0("(", paste(alpha, beta, r, x, logp, sep=", "), ")")
152
- }
153
-
154
- x <- c(0, 1, 250, 5000)
155
- print(create.test.rows(1, 1, 1, x), quote=FALSE)
156
- print(create.test.rows(1, 1, 10, x), quote=FALSE)
157
- print(create.test.rows(1, 10, 1, x), quote=FALSE)
158
- print(create.test.rows(10, 1, 1, x), quote=FALSE)
159
- print(create.test.rows(10, 10, 10, x), quote=FALSE)
160
-
161
- """
162
- alpha, beta, r, value = pt.scalars("alpha", "beta", "r", "value")
163
- logp = pm.logp(BetaNegativeBinomial.dist(alpha, beta, r), value)
164
- logp_fn = pytensor.function([value, alpha, beta, r], logp)
165
-
166
- tests = [
167
- # 1, 1, 1
168
- (1, 1, 1, 0, -0.693147180559945),
169
- (1, 1, 1, 1, -1.79175946922805),
170
- (1, 1, 1, 250, -11.0548820266432),
171
- (1, 1, 1, 5000, -17.0349862828565),
172
- # 1, 1, 10
173
- (1, 1, 10, 0, -2.39789527279837),
174
- (1, 1, 10, 1, -2.58021682959232),
175
- (1, 1, 10, 250, -8.82261694534392),
176
- (1, 1, 10, 5000, -14.7359968760473),
177
- # 1, 10, 1
178
- (1, 10, 1, 0, -2.39789527279837),
179
- (1, 10, 1, 1, -2.58021682959232),
180
- (1, 10, 1, 250, -8.82261694534418),
181
- (1, 10, 1, 5000, -14.7359968760446),
182
- # 10, 1, 1
183
- (10, 1, 1, 0, -0.0953101798043248),
184
- (10, 1, 1, 1, -2.58021682959232),
185
- (10, 1, 1, 250, -43.5891148758123),
186
- (10, 1, 1, 5000, -76.2953173311091),
187
- # 10, 10, 10
188
- (10, 10, 10, 0, -5.37909807285049),
189
- (10, 10, 10, 1, -4.17512526852455),
190
- (10, 10, 10, 250, -21.781591505836),
191
- (10, 10, 10, 5000, -53.4836799634603),
192
- ]
193
- for test_alpha, test_beta, test_r, test_value, expected_logp in tests:
194
- np.testing.assert_allclose(
195
- logp_fn(test_value, test_alpha, test_beta, test_r), expected_logp
196
- )
197
-
198
-
199
- class TestSkellam:
200
- def test_logp(self):
201
- # Scipy Skellam underflows to -inf earlier than PyMC
202
- Rplus_small = Domain([0, 0.01, 0.1, 0.9, 0.99, 1, 1.5, 2, 10, np.inf])
203
- # Suppress warnings coming from Scipy logpmf implementation
204
- with np.errstate(divide="ignore", invalid="ignore"):
205
- check_logp(
206
- Skellam,
207
- I,
208
- {"mu1": Rplus_small, "mu2": Rplus_small},
209
- lambda value, mu1, mu2: scipy.stats.skellam.logpmf(value, mu1, mu2),
210
- )
tests/distributions/test_discrete_markov_chain.py DELETED
@@ -1,258 +0,0 @@
1
- import numpy as np
2
- import pymc as pm
3
-
4
- # general imports
5
- import pytensor.tensor as pt
6
- import pytest
7
-
8
- from pymc.distributions import Categorical
9
- from pymc.distributions.shape_utils import change_dist_size
10
- from pymc.logprob.utils import ParameterValueError
11
- from pymc.sampling.mcmc import assign_step_methods
12
-
13
- from pymc_extras.distributions.timeseries import (
14
- DiscreteMarkovChain,
15
- DiscreteMarkovChainGibbsMetropolis,
16
- )
17
-
18
-
19
- def transition_probability_tests(steps, n_states, n_lags, n_draws, atol):
20
- P = np.full((n_states,) * (n_lags + 1), 1 / n_states)
21
- x0 = pm.Categorical.dist(p=np.ones(n_states) / n_states)
22
-
23
- chain = DiscreteMarkovChain.dist(
24
- P=pt.as_tensor_variable(P), init_dist=x0, steps=steps, n_lags=n_lags
25
- )
26
-
27
- draws = pm.draw(chain, n_draws, random_seed=172)
28
-
29
- # Test x0 is uniform over n_states
30
- for i in range(n_lags):
31
- assert np.allclose(
32
- np.histogram(draws[:, ..., i], bins=n_states)[0] / n_draws, 1 / n_states, atol=atol
33
- )
34
-
35
- n_grams = [[tuple(row[i : i + n_lags + 1]) for i in range(len(row) - n_lags)] for row in draws]
36
- freq_table = np.zeros((n_states,) * (n_lags + 1))
37
-
38
- for row in n_grams:
39
- for ngram in row:
40
- freq_table[ngram] += 1
41
- freq_table /= freq_table.sum(axis=-1)[:, None]
42
-
43
- # Test continuation probabilities match P
44
- assert np.allclose(P, freq_table, atol=atol)
45
-
46
-
47
- class TestDiscreteMarkovRV:
48
- def test_fail_if_P_not_square(self):
49
- P = pt.eye(3, 2)
50
- x0 = pm.Categorical.dist(p=np.ones(3) / 3)
51
-
52
- chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, steps=3)
53
- with pytest.raises(ParameterValueError):
54
- pm.logp(chain, np.zeros((3,))).eval()
55
-
56
- def test_fail_if_P_not_valid(self):
57
- P = pt.zeros((3, 3))
58
- x0 = pm.Categorical.dist(p=np.ones(3) / 3)
59
- chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, steps=3)
60
- with pytest.raises(ParameterValueError):
61
- pm.logp(chain, np.zeros((3,))).eval()
62
-
63
- def test_high_dimensional_P(self):
64
- P = pm.Dirichlet.dist(a=pt.ones(3), size=(3, 3, 3))
65
- n_lags = 3
66
- x0 = pm.Categorical.dist(p=np.ones(3) / 3)
67
- chain = DiscreteMarkovChain.dist(P=P, steps=10, init_dist=x0, n_lags=n_lags)
68
- draws = pm.draw(chain, 10)
69
- logp = pm.logp(chain, draws)
70
-
71
- def test_default_init_dist_warns_user(self):
72
- P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))
73
-
74
- with pytest.warns(UserWarning):
75
- DiscreteMarkovChain.dist(P=P, steps=3)
76
-
77
- def test_logp_shape(self):
78
- P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))
79
- x0 = pm.Categorical.dist(p=np.ones(3) / 3)
80
-
81
- # Test with steps
82
- chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, steps=3)
83
- draws = pm.draw(chain, 5)
84
- logp = pm.logp(chain, draws).eval()
85
-
86
- assert logp.shape == (5,)
87
-
88
- # Test with shape
89
- chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, shape=(3,))
90
- draws = pm.draw(chain, 5)
91
- logp = pm.logp(chain, draws).eval()
92
-
93
- assert logp.shape == (5,)
94
-
95
- def test_logp_with_default_init_dist(self):
96
- P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))
97
- x0 = pm.Categorical.dist(p=np.ones(3) / 3)
98
-
99
- value = np.array([0, 1, 2])
100
- logp_expected = np.log((1 / 3) * 0.5 * 0.3)
101
-
102
- # Test dist directly
103
- chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, steps=3)
104
- logp_eval = pm.logp(chain, value).eval()
105
- np.testing.assert_allclose(logp_eval, logp_expected, rtol=1e-6)
106
-
107
- # Test via Model
108
- with pm.Model() as m:
109
- DiscreteMarkovChain("chain", P=P, init_dist=x0, steps=3)
110
- model_logp_eval = m.compile_logp()({"chain": value})
111
- np.testing.assert_allclose(model_logp_eval, logp_expected, rtol=1e-6)
112
-
113
- def test_logp_with_user_defined_init_dist(self):
114
- P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))
115
- x0 = pm.Categorical.dist(p=[0.2, 0.6, 0.2])
116
- chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, steps=3)
117
-
118
- logp = pm.logp(chain, [0, 1, 2]).eval()
119
- assert logp == np.log(0.2 * 0.5 * 0.3)
120
-
121
- def test_moment_function(self):
122
- P_np = np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]])
123
-
124
- x0_np = np.array([0, 1, 0])
125
-
126
- P = pt.as_tensor_variable(P_np)
127
- x0 = pm.Categorical.dist(p=x0_np.tolist())
128
- n_steps = 3
129
-
130
- chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, steps=n_steps)
131
-
132
- chain_np = np.empty(shape=n_steps + 1, dtype="int8")
133
- chain_np[0] = np.argmax(x0_np)
134
- for i in range(n_steps):
135
- state = chain_np[i]
136
- chain_np[i + 1] = np.argmax(P_np[state])
137
-
138
- dmc_chain = pm.distributions.distribution.support_point(chain).eval()
139
-
140
- assert np.allclose(dmc_chain, chain_np)
141
-
142
- def test_define_steps_via_shape_arg(self):
143
- P = pt.full((3, 3), 1 / 3)
144
- x0 = pm.Categorical.dist(p=np.ones(3) / 3)
145
-
146
- chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, shape=(3,))
147
- assert chain.eval().shape == (3,)
148
-
149
- chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, shape=(3, 2))
150
- assert chain.eval().shape == (3, 2)
151
-
152
- def test_define_steps_via_dim_arg(self):
153
- coords = {"steps": [1, 2, 3]}
154
-
155
- with pm.Model(coords=coords):
156
- P = pt.full((3, 3), 1 / 3)
157
- x0 = pm.Categorical.dist(p=np.ones(3) / 3)
158
-
159
- chain = DiscreteMarkovChain("chain", P=P, init_dist=x0, dims=["steps"])
160
-
161
- assert chain.eval().shape == (3,)
162
-
163
- def test_dims_when_steps_are_defined(self):
164
- coords = {"steps": [1, 2, 3, 4]}
165
-
166
- with pm.Model(coords=coords):
167
- P = pt.full((3, 3), 1 / 3)
168
- x0 = pm.Categorical.dist(p=np.ones(3) / 3)
169
-
170
- chain = DiscreteMarkovChain("chain", P=P, steps=3, init_dist=x0, dims=["steps"])
171
-
172
- assert chain.eval().shape == (4,)
173
-
174
- def test_multiple_dims_with_steps(self):
175
- coords = {"steps": [1, 2, 3], "mc_chains": [1, 2, 3]}
176
-
177
- with pm.Model(coords=coords):
178
- P = pt.full((3, 3), 1 / 3)
179
- x0 = pm.Categorical.dist(p=np.ones(3) / 3)
180
-
181
- chain = DiscreteMarkovChain(
182
- "chain", P=P, steps=2, init_dist=x0, dims=["steps", "mc_chains"]
183
- )
184
-
185
- assert chain.eval().shape == (3, 3)
186
-
187
- def test_mutiple_dims_with_steps_and_init_dist(self):
188
- coords = {"steps": [1, 2, 3], "mc_chains": [1, 2, 3]}
189
-
190
- with pm.Model(coords=coords):
191
- P = pt.full((3, 3), 1 / 3)
192
- x0 = pm.Categorical.dist(p=[0.1, 0.1, 0.8], size=(3,))
193
- chain = DiscreteMarkovChain(
194
- "chain", P=P, init_dist=x0, steps=2, dims=["steps", "mc_chains"]
195
- )
196
-
197
- assert chain.eval().shape == (3, 3)
198
-
199
- def test_multiple_lags_with_data(self):
200
- with pm.Model():
201
- P = pt.full((3, 3, 3), 1 / 3)
202
- x0 = pm.Categorical.dist(p=[0.1, 0.1, 0.8], size=2)
203
- data = pm.draw(x0, 100)
204
-
205
- chain = DiscreteMarkovChain("chain", P=P, init_dist=x0, n_lags=2, observed=data)
206
-
207
- assert chain.eval().shape == (100, 2)
208
-
209
- def test_random_draws(self):
210
- transition_probability_tests(steps=3, n_states=2, n_lags=1, n_draws=2500, atol=0.05)
211
- transition_probability_tests(steps=3, n_states=2, n_lags=3, n_draws=7500, atol=0.05)
212
-
213
- def test_change_size_univariate(self):
214
- P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))
215
- x0 = pm.Categorical.dist(p=np.ones(3) / 3)
216
-
217
- chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, shape=(100, 5))
218
-
219
- new_rw = change_dist_size(chain, new_size=(7,))
220
- assert tuple(new_rw.shape.eval()) == (7, 5)
221
-
222
- new_rw = change_dist_size(chain, new_size=(4, 3), expand=True)
223
- assert tuple(new_rw.shape.eval()) == (4, 3, 100, 5)
224
-
225
- def test_mcmc_sampling(self):
226
- with pm.Model(coords={"step": range(100)}) as model:
227
- init_dist = Categorical.dist(p=[0.5, 0.5])
228
- markov_chain = DiscreteMarkovChain(
229
- "markov_chain",
230
- P=[[0.1, 0.9], [0.1, 0.9]],
231
- init_dist=init_dist,
232
- shape=(100,),
233
- dims="step",
234
- )
235
-
236
- _, assigned_step_methods = assign_step_methods(model)
237
- assert assigned_step_methods[DiscreteMarkovChainGibbsMetropolis] == [
238
- model.rvs_to_values[markov_chain]
239
- ]
240
-
241
- # Sampler needs no tuning
242
- idata = pm.sample(
243
- tune=0, chains=4, draws=250, progressbar=False, compute_convergence_checks=False
244
- )
245
-
246
- np.testing.assert_allclose(
247
- idata.posterior["markov_chain"].isel(step=0).mean(("chain", "draw")),
248
- 0.5,
249
- atol=0.05,
250
- )
251
-
252
- np.testing.assert_allclose(
253
- idata.posterior["markov_chain"].isel(step=slice(1, None)).mean(("chain", "draw")),
254
- 0.9,
255
- atol=0.05,
256
- )
257
-
258
- assert pm.stats.ess(idata, method="tail").min() > 950