pymc-extras 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. pymc_extras/__init__.py +5 -1
  2. pymc_extras/deserialize.py +224 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/inference/find_map.py +62 -17
  6. pymc_extras/inference/laplace.py +10 -7
  7. pymc_extras/prior.py +1356 -0
  8. pymc_extras/statespace/core/statespace.py +191 -52
  9. pymc_extras/statespace/filters/distributions.py +15 -16
  10. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  11. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  12. pymc_extras/statespace/models/ETS.py +10 -0
  13. pymc_extras/statespace/models/SARIMAX.py +26 -5
  14. pymc_extras/statespace/models/VARMAX.py +12 -2
  15. pymc_extras/statespace/models/structural.py +18 -5
  16. pymc_extras-0.2.7.dist-info/METADATA +321 -0
  17. pymc_extras-0.2.7.dist-info/RECORD +66 -0
  18. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/WHEEL +1 -2
  19. pymc_extras/utils/pivoted_cholesky.py +0 -69
  20. pymc_extras/version.py +0 -11
  21. pymc_extras/version.txt +0 -1
  22. pymc_extras-0.2.5.dist-info/METADATA +0 -112
  23. pymc_extras-0.2.5.dist-info/RECORD +0 -108
  24. pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
  25. tests/__init__.py +0 -13
  26. tests/distributions/__init__.py +0 -19
  27. tests/distributions/test_continuous.py +0 -185
  28. tests/distributions/test_discrete.py +0 -210
  29. tests/distributions/test_discrete_markov_chain.py +0 -258
  30. tests/distributions/test_multivariate.py +0 -304
  31. tests/distributions/test_transform.py +0 -77
  32. tests/model/__init__.py +0 -0
  33. tests/model/marginal/__init__.py +0 -0
  34. tests/model/marginal/test_distributions.py +0 -132
  35. tests/model/marginal/test_graph_analysis.py +0 -182
  36. tests/model/marginal/test_marginal_model.py +0 -967
  37. tests/model/test_model_api.py +0 -38
  38. tests/statespace/__init__.py +0 -0
  39. tests/statespace/test_ETS.py +0 -411
  40. tests/statespace/test_SARIMAX.py +0 -405
  41. tests/statespace/test_VARMAX.py +0 -184
  42. tests/statespace/test_coord_assignment.py +0 -181
  43. tests/statespace/test_distributions.py +0 -270
  44. tests/statespace/test_kalman_filter.py +0 -326
  45. tests/statespace/test_representation.py +0 -175
  46. tests/statespace/test_statespace.py +0 -872
  47. tests/statespace/test_statespace_JAX.py +0 -156
  48. tests/statespace/test_structural.py +0 -836
  49. tests/statespace/utilities/__init__.py +0 -0
  50. tests/statespace/utilities/shared_fixtures.py +0 -9
  51. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  52. tests/statespace/utilities/test_helpers.py +0 -310
  53. tests/test_blackjax_smc.py +0 -222
  54. tests/test_find_map.py +0 -103
  55. tests/test_histogram_approximation.py +0 -109
  56. tests/test_laplace.py +0 -281
  57. tests/test_linearmodel.py +0 -208
  58. tests/test_model_builder.py +0 -306
  59. tests/test_pathfinder.py +0 -297
  60. tests/test_pivoted_cholesky.py +0 -24
  61. tests/test_printing.py +0 -98
  62. tests/test_prior_from_trace.py +0 -172
  63. tests/test_splines.py +0 -77
  64. tests/utils.py +0 -0
  65. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/licenses/LICENSE +0 -0
tests/model/marginal/test_distributions.py +0 -132
@@ -1,132 +0,0 @@
1
- import numpy as np
2
- import pymc as pm
3
- import pytest
4
-
5
- from pymc.logprob.abstract import _logprob
6
- from pytensor import tensor as pt
7
- from scipy.stats import norm
8
-
9
- from pymc_extras import marginalize
10
- from pymc_extras.distributions import DiscreteMarkovChain
11
- from pymc_extras.model.marginal.distributions import MarginalFiniteDiscreteRV
12
-
13
-
14
- def test_marginalized_bernoulli_logp():
15
- """Test logp of IR TestFiniteMarginalDiscreteRV directly"""
16
- mu = pt.vector("mu")
17
-
18
- idx = pm.Bernoulli.dist(0.7, name="idx")
19
- y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y")
20
- marginal_rv_node = MarginalFiniteDiscreteRV(
21
- [mu],
22
- [idx, y],
23
- dims_connections=(((),),),
24
- dims=(),
25
- )(mu)[0].owner
26
-
27
- y_vv = y.clone()
28
- (logp,) = _logprob(
29
- marginal_rv_node.op,
30
- (y_vv,),
31
- *marginal_rv_node.inputs,
32
- )
33
-
34
- ref_logp = pm.logp(pm.NormalMixture.dist(w=[0.3, 0.7], mu=mu, sigma=1.0), y_vv)
35
- np.testing.assert_almost_equal(
36
- logp.eval({mu: [-1, 1], y_vv: 2}),
37
- ref_logp.eval({mu: [-1, 1], y_vv: 2}),
38
- )
39
-
40
-
41
- @pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}")
42
- @pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}")
43
- def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
44
- if batch_chain and not batch_emission:
45
- pytest.skip("Redundant implicit combination")
46
-
47
- with pm.Model() as m:
48
- P = [[0, 1], [1, 0]]
49
- init_dist = pm.Categorical.dist(p=[1, 0])
50
- chain = DiscreteMarkovChain(
51
- "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None
52
- )
53
- emission = pm.Normal(
54
- "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None
55
- )
56
-
57
- marginal_m = marginalize(m, [chain])
58
- logp_fn = marginal_m.compile_logp()
59
-
60
- test_value = np.array([-1, 1, -1, 1])
61
- expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
62
- if batch_emission:
63
- test_value = np.broadcast_to(test_value, (3, 4))
64
- expected_logp *= 3
65
- np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
66
-
67
-
68
- @pytest.mark.parametrize(
69
- "categorical_emission",
70
- [False, True],
71
- )
72
- def test_marginalized_hmm_categorical_emission(categorical_emission):
73
- """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
74
- with pm.Model() as m:
75
- P = np.array([[0.5, 0.5], [0.3, 0.7]])
76
- init_dist = pm.Categorical.dist(p=[0.375, 0.625])
77
- chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
78
- if categorical_emission:
79
- emission = pm.Categorical("emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain])
80
- else:
81
- emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
82
- marginal_m = marginalize(m, [chain])
83
-
84
- test_value = np.array([0, 0, 1])
85
- expected_logp = np.log(0.1344) # Shown at the 10m22s mark in the video
86
- logp_fn = marginal_m.compile_logp()
87
- np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
88
-
89
-
90
- @pytest.mark.parametrize("batch_chain", (False, True))
91
- @pytest.mark.parametrize("batch_emission1", (False, True))
92
- @pytest.mark.parametrize("batch_emission2", (False, True))
93
- def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch_emission2):
94
- chain_shape = (3, 1, 4) if batch_chain else (4,)
95
- emission1_shape = (
96
- (2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape))
97
- )
98
- emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape
99
- with pm.Model() as m:
100
- P = [[0, 1], [1, 0]]
101
- init_dist = pm.Categorical.dist(p=[1, 0])
102
- chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape)
103
- emission_1 = pm.Normal(
104
- "emission_1", mu=(chain * 2 - 1).T, sigma=1e-1, shape=emission1_shape
105
- )
106
-
107
- emission2_mu = (1 - chain) * 2 - 1
108
- if batch_emission2:
109
- emission2_mu = emission2_mu[..., None]
110
- emission_2 = pm.Normal("emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape)
111
-
112
- marginal_m = marginalize(m, [chain])
113
-
114
- with pytest.warns(UserWarning, match="multiple dependent variables"):
115
- logp_fn = marginal_m.compile_logp(sum=False)
116
-
117
- test_value = np.array([-1, 1, -1, 1])
118
- multiplier = 2 + batch_emission1 + batch_emission2
119
- if batch_chain:
120
- multiplier *= 3
121
- expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier
122
-
123
- test_value = np.broadcast_to(test_value, chain_shape)
124
- test_value_emission1 = np.broadcast_to(test_value.T, emission1_shape)
125
- if batch_emission2:
126
- test_value_emission2 = np.broadcast_to(-test_value[..., None], emission2_shape)
127
- else:
128
- test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
129
- test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
130
- res_logp, dummy_logp = logp_fn(test_point)
131
- assert res_logp.shape == ((1, 3) if batch_chain else ())
132
- np.testing.assert_allclose(res_logp.sum(), expected_logp)
tests/model/marginal/test_graph_analysis.py +0 -182
@@ -1,182 +0,0 @@
1
- import pytensor.tensor as pt
2
- import pytest
3
-
4
- from pymc.distributions import CustomDist
5
- from pytensor.tensor.type_other import NoneTypeT
6
-
7
- from pymc_extras.model.marginal.graph_analysis import (
8
- is_conditional_dependent,
9
- subgraph_batch_dim_connection,
10
- )
11
-
12
-
13
- def test_is_conditional_dependent_static_shape():
14
- """Test that we don't consider dependencies through "constant" shape Ops"""
15
- x1 = pt.matrix("x1", shape=(None, 5))
16
- y1 = pt.random.normal(size=pt.shape(x1))
17
- assert is_conditional_dependent(y1, x1, [x1, y1])
18
-
19
- x2 = pt.matrix("x2", shape=(9, 5))
20
- y2 = pt.random.normal(size=pt.shape(x2))
21
- assert not is_conditional_dependent(y2, x2, [x2, y2])
22
-
23
-
24
- class TestSubgraphBatchDimConnection:
25
- def test_dimshuffle(self):
26
- inp = pt.tensor(shape=(5, 1, 4, 3))
27
- out1 = pt.matrix_transpose(inp)
28
- out2 = pt.expand_dims(inp, 1)
29
- out3 = pt.squeeze(inp)
30
- [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3])
31
- assert dims1 == (0, 1, 3, 2)
32
- assert dims2 == (0, None, 1, 2, 3)
33
- assert dims3 == (0, 2, 3)
34
-
35
- def test_careduce(self):
36
- inp = pt.tensor(shape=(4, 3, 2))
37
-
38
- out = pt.sum(inp[:, None], axis=(1,))
39
- [dims] = subgraph_batch_dim_connection(inp, [out])
40
- assert dims == (0, 1, 2)
41
-
42
- invalid_out = pt.sum(inp, axis=(1,))
43
- with pytest.raises(ValueError, match="Use of known dimensions"):
44
- subgraph_batch_dim_connection(inp, [invalid_out])
45
-
46
- def test_subtensor(self):
47
- inp = pt.tensor(shape=(4, 3, 2))
48
-
49
- invalid_out = inp[0, :1]
50
- with pytest.raises(
51
- ValueError,
52
- match="Partial slicing or indexing of known dimensions not supported",
53
- ):
54
- subgraph_batch_dim_connection(inp, [invalid_out])
55
-
56
- # If we are selecting dummy / unknown dimensions that's fine
57
- valid_out = pt.expand_dims(inp, (0, 1))[0, :1]
58
- [dims] = subgraph_batch_dim_connection(inp, [valid_out])
59
- assert dims == (None, 0, 1, 2)
60
-
61
- def test_advanced_subtensor_value(self):
62
- inp = pt.tensor(shape=(2, 4))
63
- intermediate_out = inp[:, None, :, None] + pt.zeros((2, 3, 4, 5))
64
-
65
- # Index on an unlabled dim introduced by broadcasting with zeros
66
- out = intermediate_out[:, [0, 0, 1, 2]]
67
- [dims] = subgraph_batch_dim_connection(inp, [out])
68
- assert dims == (0, None, 1, None)
69
-
70
- # Indexing that introduces more dimensions
71
- out = intermediate_out[:, [[0, 0], [1, 2]], :]
72
- [dims] = subgraph_batch_dim_connection(inp, [out])
73
- assert dims == (0, None, None, 1, None)
74
-
75
- # Special case where advanced dims are moved to the front of the output
76
- out = intermediate_out[:, [0, 0, 1, 2], :, 0]
77
- [dims] = subgraph_batch_dim_connection(inp, [out])
78
- assert dims == (None, 0, 1)
79
-
80
- # Indexing on a labeled dim fails
81
- out = intermediate_out[:, :, [0, 0, 1, 2]]
82
- with pytest.raises(ValueError, match="Partial slicing or advanced integer indexing"):
83
- subgraph_batch_dim_connection(inp, [out])
84
-
85
- def test_advanced_subtensor_key(self):
86
- inp = pt.tensor(shape=(5, 5), dtype=int)
87
- base = pt.zeros((2, 3, 4))
88
-
89
- out = base[inp]
90
- [dims] = subgraph_batch_dim_connection(inp, [out])
91
- assert dims == (0, 1, None, None)
92
-
93
- out = base[:, :, inp]
94
- [dims] = subgraph_batch_dim_connection(inp, [out])
95
- assert dims == (
96
- None,
97
- None,
98
- 0,
99
- 1,
100
- )
101
-
102
- out = base[1:, 0, inp]
103
- [dims] = subgraph_batch_dim_connection(inp, [out])
104
- assert dims == (None, 0, 1)
105
-
106
- # Special case where advanced dims are moved to the front of the output
107
- out = base[0, :, inp]
108
- [dims] = subgraph_batch_dim_connection(inp, [out])
109
- assert dims == (0, 1, None)
110
-
111
- # Mix keys dimensions
112
- out = base[:, inp, inp.T]
113
- with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"):
114
- subgraph_batch_dim_connection(inp, [out])
115
-
116
- def test_elemwise(self):
117
- inp = pt.tensor(shape=(5, 5))
118
-
119
- out = inp + inp
120
- [dims] = subgraph_batch_dim_connection(inp, [out])
121
- assert dims == (0, 1)
122
-
123
- out = inp + inp.T
124
- with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"):
125
- subgraph_batch_dim_connection(inp, [out])
126
-
127
- out = inp[None, :, None, :] + inp[:, None, :, None]
128
- with pytest.raises(
129
- ValueError, match="Same known dimension used in different axis after broadcasting"
130
- ):
131
- subgraph_batch_dim_connection(inp, [out])
132
-
133
- def test_blockwise(self):
134
- inp = pt.tensor(shape=(5, 4))
135
-
136
- invalid_out = inp @ pt.ones((4, 3))
137
- with pytest.raises(ValueError, match="Use of known dimensions"):
138
- subgraph_batch_dim_connection(inp, [invalid_out])
139
-
140
- out = (inp[:, :, None, None] + pt.zeros((2, 3))) @ pt.ones((2, 3))
141
- [dims] = subgraph_batch_dim_connection(inp, [out])
142
- assert dims == (0, 1, None, None)
143
-
144
- def test_random_variable(self):
145
- inp = pt.tensor(shape=(5, 4, 3))
146
-
147
- out1 = pt.random.normal(loc=inp)
148
- out2 = pt.random.categorical(p=inp[..., None])
149
- out3 = pt.random.multivariate_normal(mean=inp[..., None], cov=pt.eye(1))
150
- [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3])
151
- assert dims1 == (0, 1, 2)
152
- assert dims2 == (0, 1, 2)
153
- assert dims3 == (0, 1, 2, None)
154
-
155
- invalid_out = pt.random.categorical(p=inp)
156
- with pytest.raises(ValueError, match="Use of known dimensions"):
157
- subgraph_batch_dim_connection(inp, [invalid_out])
158
-
159
- invalid_out = pt.random.multivariate_normal(mean=inp, cov=pt.eye(3))
160
- with pytest.raises(ValueError, match="Use of known dimensions"):
161
- subgraph_batch_dim_connection(inp, [invalid_out])
162
-
163
- def test_symbolic_random_variable(self):
164
- inp = pt.tensor(shape=(4, 3, 2))
165
-
166
- # Test univariate
167
- out = CustomDist.dist(
168
- inp,
169
- dist=lambda mu, size: pt.random.normal(loc=mu, size=size),
170
- )
171
- [dims] = subgraph_batch_dim_connection(inp, [out])
172
- assert dims == (0, 1, 2)
173
-
174
- # Test multivariate
175
- def dist(mu, size):
176
- if isinstance(size.type, NoneTypeT):
177
- size = mu.shape
178
- return pt.random.normal(loc=mu[..., None], size=(*size, 2))
179
-
180
- out = CustomDist.dist(inp, dist=dist, size=(4, 3, 2), signature="()->(2)")
181
- [dims] = subgraph_batch_dim_connection(inp, [out])
182
- assert dims == (0, 1, 2, None)