pymc_extras-0.2.5-py3-none-any.whl → pymc_extras-0.2.7-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- pymc_extras/__init__.py +5 -1
- pymc_extras/deserialize.py +224 -0
- pymc_extras/distributions/continuous.py +3 -2
- pymc_extras/distributions/discrete.py +3 -1
- pymc_extras/inference/find_map.py +62 -17
- pymc_extras/inference/laplace.py +10 -7
- pymc_extras/prior.py +1356 -0
- pymc_extras/statespace/core/statespace.py +191 -52
- pymc_extras/statespace/filters/distributions.py +15 -16
- pymc_extras/statespace/filters/kalman_filter.py +1 -18
- pymc_extras/statespace/filters/kalman_smoother.py +2 -6
- pymc_extras/statespace/models/ETS.py +10 -0
- pymc_extras/statespace/models/SARIMAX.py +26 -5
- pymc_extras/statespace/models/VARMAX.py +12 -2
- pymc_extras/statespace/models/structural.py +18 -5
- pymc_extras-0.2.7.dist-info/METADATA +321 -0
- pymc_extras-0.2.7.dist-info/RECORD +66 -0
- {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/WHEEL +1 -2
- pymc_extras/utils/pivoted_cholesky.py +0 -69
- pymc_extras/version.py +0 -11
- pymc_extras/version.txt +0 -1
- pymc_extras-0.2.5.dist-info/METADATA +0 -112
- pymc_extras-0.2.5.dist-info/RECORD +0 -108
- pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -13
- tests/distributions/__init__.py +0 -19
- tests/distributions/test_continuous.py +0 -185
- tests/distributions/test_discrete.py +0 -210
- tests/distributions/test_discrete_markov_chain.py +0 -258
- tests/distributions/test_multivariate.py +0 -304
- tests/distributions/test_transform.py +0 -77
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +0 -132
- tests/model/marginal/test_graph_analysis.py +0 -182
- tests/model/marginal/test_marginal_model.py +0 -967
- tests/model/test_model_api.py +0 -38
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +0 -411
- tests/statespace/test_SARIMAX.py +0 -405
- tests/statespace/test_VARMAX.py +0 -184
- tests/statespace/test_coord_assignment.py +0 -181
- tests/statespace/test_distributions.py +0 -270
- tests/statespace/test_kalman_filter.py +0 -326
- tests/statespace/test_representation.py +0 -175
- tests/statespace/test_statespace.py +0 -872
- tests/statespace/test_statespace_JAX.py +0 -156
- tests/statespace/test_structural.py +0 -836
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +0 -9
- tests/statespace/utilities/statsmodel_local_level.py +0 -42
- tests/statespace/utilities/test_helpers.py +0 -310
- tests/test_blackjax_smc.py +0 -222
- tests/test_find_map.py +0 -103
- tests/test_histogram_approximation.py +0 -109
- tests/test_laplace.py +0 -281
- tests/test_linearmodel.py +0 -208
- tests/test_model_builder.py +0 -306
- tests/test_pathfinder.py +0 -297
- tests/test_pivoted_cholesky.py +0 -24
- tests/test_printing.py +0 -98
- tests/test_prior_from_trace.py +0 -172
- tests/test_splines.py +0 -77
- tests/utils.py +0 -0
- {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/licenses/LICENSE +0 -0
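Two of the larger additions above are the new `pymc_extras/prior.py` (+1356) and `pymc_extras/deserialize.py` (+224) modules. Below is a minimal, hypothetical usage sketch, assuming these modules mirror the `Prior`/`deserialize` API of pymc-marketing, from which they appear to be ported; none of the names used here are confirmed by this diff.

```python
# Hypothetical sketch only: assumes pymc_extras.prior / pymc_extras.deserialize
# expose the same Prior / deserialize names as their pymc-marketing counterparts.
import pymc as pm

from pymc_extras.deserialize import deserialize
from pymc_extras.prior import Prior

# Declare a (possibly hierarchical) prior as data rather than as raw PyMC calls.
mu_prior = Prior("Normal", mu=0, sigma=Prior("HalfNormal", sigma=1))

with pm.Model():
    # Expands the declarative spec into the corresponding PyMC random variables.
    mu = mu_prior.create_variable("mu")

# Round-trip the spec through a plain-dict representation.
spec = mu_prior.to_dict()
restored = deserialize(spec)
```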
tests/model/marginal/test_distributions.py
@@ -1,132 +0,0 @@
-import numpy as np
-import pymc as pm
-import pytest
-
-from pymc.logprob.abstract import _logprob
-from pytensor import tensor as pt
-from scipy.stats import norm
-
-from pymc_extras import marginalize
-from pymc_extras.distributions import DiscreteMarkovChain
-from pymc_extras.model.marginal.distributions import MarginalFiniteDiscreteRV
-
-
-def test_marginalized_bernoulli_logp():
-    """Test logp of IR TestFiniteMarginalDiscreteRV directly"""
-    mu = pt.vector("mu")
-
-    idx = pm.Bernoulli.dist(0.7, name="idx")
-    y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y")
-    marginal_rv_node = MarginalFiniteDiscreteRV(
-        [mu],
-        [idx, y],
-        dims_connections=(((),),),
-        dims=(),
-    )(mu)[0].owner
-
-    y_vv = y.clone()
-    (logp,) = _logprob(
-        marginal_rv_node.op,
-        (y_vv,),
-        *marginal_rv_node.inputs,
-    )
-
-    ref_logp = pm.logp(pm.NormalMixture.dist(w=[0.3, 0.7], mu=mu, sigma=1.0), y_vv)
-    np.testing.assert_almost_equal(
-        logp.eval({mu: [-1, 1], y_vv: 2}),
-        ref_logp.eval({mu: [-1, 1], y_vv: 2}),
-    )
-
-
-@pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}")
-@pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}")
-def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
-    if batch_chain and not batch_emission:
-        pytest.skip("Redundant implicit combination")
-
-    with pm.Model() as m:
-        P = [[0, 1], [1, 0]]
-        init_dist = pm.Categorical.dist(p=[1, 0])
-        chain = DiscreteMarkovChain(
-            "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None
-        )
-        emission = pm.Normal(
-            "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None
-        )
-
-    marginal_m = marginalize(m, [chain])
-    logp_fn = marginal_m.compile_logp()
-
-    test_value = np.array([-1, 1, -1, 1])
-    expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
-    if batch_emission:
-        test_value = np.broadcast_to(test_value, (3, 4))
-        expected_logp *= 3
-    np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
-
-
-@pytest.mark.parametrize(
-    "categorical_emission",
-    [False, True],
-)
-def test_marginalized_hmm_categorical_emission(categorical_emission):
-    """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
-    with pm.Model() as m:
-        P = np.array([[0.5, 0.5], [0.3, 0.7]])
-        init_dist = pm.Categorical.dist(p=[0.375, 0.625])
-        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
-        if categorical_emission:
-            emission = pm.Categorical("emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain])
-        else:
-            emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
-    marginal_m = marginalize(m, [chain])
-
-    test_value = np.array([0, 0, 1])
-    expected_logp = np.log(0.1344)  # Shown at the 10m22s mark in the video
-    logp_fn = marginal_m.compile_logp()
-    np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
-
-
-@pytest.mark.parametrize("batch_chain", (False, True))
-@pytest.mark.parametrize("batch_emission1", (False, True))
-@pytest.mark.parametrize("batch_emission2", (False, True))
-def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch_emission2):
-    chain_shape = (3, 1, 4) if batch_chain else (4,)
-    emission1_shape = (
-        (2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape))
-    )
-    emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape
-    with pm.Model() as m:
-        P = [[0, 1], [1, 0]]
-        init_dist = pm.Categorical.dist(p=[1, 0])
-        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape)
-        emission_1 = pm.Normal(
-            "emission_1", mu=(chain * 2 - 1).T, sigma=1e-1, shape=emission1_shape
-        )
-
-        emission2_mu = (1 - chain) * 2 - 1
-        if batch_emission2:
-            emission2_mu = emission2_mu[..., None]
-        emission_2 = pm.Normal("emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape)
-
-    marginal_m = marginalize(m, [chain])
-
-    with pytest.warns(UserWarning, match="multiple dependent variables"):
-        logp_fn = marginal_m.compile_logp(sum=False)
-
-    test_value = np.array([-1, 1, -1, 1])
-    multiplier = 2 + batch_emission1 + batch_emission2
-    if batch_chain:
-        multiplier *= 3
-    expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier
-
-    test_value = np.broadcast_to(test_value, chain_shape)
-    test_value_emission1 = np.broadcast_to(test_value.T, emission1_shape)
-    if batch_emission2:
-        test_value_emission2 = np.broadcast_to(-test_value[..., None], emission2_shape)
-    else:
-        test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
-    test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
-    res_logp, dummy_logp = logp_fn(test_point)
-    assert res_logp.shape == ((1, 3) if batch_chain else ())
-    np.testing.assert_allclose(res_logp.sum(), expected_logp)
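The module deleted above was the test suite for the public `marginalize` entry point; the deletion reflects the wheel apparently no longer shipping a top-level `tests` package, not the feature being dropped. A minimal sketch of that API, distilled from the usage in the deleted test (the model construction here is illustrative, not taken verbatim from the source):

```python
# Minimal marginalize() sketch distilled from the deleted test above.
import numpy as np
import pymc as pm

from pymc_extras import marginalize

with pm.Model() as m:
    # Finite discrete variable to be marginalized out.
    idx = pm.Bernoulli("idx", p=0.7)
    y = pm.Normal("y", mu=pm.math.switch(idx, 1.0, -1.0), sigma=1.0)

# Returns a new model in which `idx` is summed out of the joint density,
# leaving a two-component normal mixture over `y`.
marginal_m = marginalize(m, [idx])
logp_fn = marginal_m.compile_logp()
print(logp_fn({"y": np.array(0.5)}))
```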
tests/model/marginal/test_graph_analysis.py
@@ -1,182 +0,0 @@
-import pytensor.tensor as pt
-import pytest
-
-from pymc.distributions import CustomDist
-from pytensor.tensor.type_other import NoneTypeT
-
-from pymc_extras.model.marginal.graph_analysis import (
-    is_conditional_dependent,
-    subgraph_batch_dim_connection,
-)
-
-
-def test_is_conditional_dependent_static_shape():
-    """Test that we don't consider dependencies through "constant" shape Ops"""
-    x1 = pt.matrix("x1", shape=(None, 5))
-    y1 = pt.random.normal(size=pt.shape(x1))
-    assert is_conditional_dependent(y1, x1, [x1, y1])
-
-    x2 = pt.matrix("x2", shape=(9, 5))
-    y2 = pt.random.normal(size=pt.shape(x2))
-    assert not is_conditional_dependent(y2, x2, [x2, y2])
-
-
-class TestSubgraphBatchDimConnection:
-    def test_dimshuffle(self):
-        inp = pt.tensor(shape=(5, 1, 4, 3))
-        out1 = pt.matrix_transpose(inp)
-        out2 = pt.expand_dims(inp, 1)
-        out3 = pt.squeeze(inp)
-        [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3])
-        assert dims1 == (0, 1, 3, 2)
-        assert dims2 == (0, None, 1, 2, 3)
-        assert dims3 == (0, 2, 3)
-
-    def test_careduce(self):
-        inp = pt.tensor(shape=(4, 3, 2))
-
-        out = pt.sum(inp[:, None], axis=(1,))
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1, 2)
-
-        invalid_out = pt.sum(inp, axis=(1,))
-        with pytest.raises(ValueError, match="Use of known dimensions"):
-            subgraph_batch_dim_connection(inp, [invalid_out])
-
-    def test_subtensor(self):
-        inp = pt.tensor(shape=(4, 3, 2))
-
-        invalid_out = inp[0, :1]
-        with pytest.raises(
-            ValueError,
-            match="Partial slicing or indexing of known dimensions not supported",
-        ):
-            subgraph_batch_dim_connection(inp, [invalid_out])
-
-        # If we are selecting dummy / unknown dimensions that's fine
-        valid_out = pt.expand_dims(inp, (0, 1))[0, :1]
-        [dims] = subgraph_batch_dim_connection(inp, [valid_out])
-        assert dims == (None, 0, 1, 2)
-
-    def test_advanced_subtensor_value(self):
-        inp = pt.tensor(shape=(2, 4))
-        intermediate_out = inp[:, None, :, None] + pt.zeros((2, 3, 4, 5))
-
-        # Index on an unlabeled dim introduced by broadcasting with zeros
-        out = intermediate_out[:, [0, 0, 1, 2]]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, None, 1, None)
-
-        # Indexing that introduces more dimensions
-        out = intermediate_out[:, [[0, 0], [1, 2]], :]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, None, None, 1, None)
-
-        # Special case where advanced dims are moved to the front of the output
-        out = intermediate_out[:, [0, 0, 1, 2], :, 0]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (None, 0, 1)
-
-        # Indexing on a labeled dim fails
-        out = intermediate_out[:, :, [0, 0, 1, 2]]
-        with pytest.raises(ValueError, match="Partial slicing or advanced integer indexing"):
-            subgraph_batch_dim_connection(inp, [out])
-
-    def test_advanced_subtensor_key(self):
-        inp = pt.tensor(shape=(5, 5), dtype=int)
-        base = pt.zeros((2, 3, 4))
-
-        out = base[inp]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1, None, None)
-
-        out = base[:, :, inp]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (
-            None,
-            None,
-            0,
-            1,
-        )
-
-        out = base[1:, 0, inp]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (None, 0, 1)
-
-        # Special case where advanced dims are moved to the front of the output
-        out = base[0, :, inp]
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1, None)
-
-        # Mixing dimensions across keys fails
-        out = base[:, inp, inp.T]
-        with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"):
-            subgraph_batch_dim_connection(inp, [out])
-
-    def test_elemwise(self):
-        inp = pt.tensor(shape=(5, 5))
-
-        out = inp + inp
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1)
-
-        out = inp + inp.T
-        with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"):
-            subgraph_batch_dim_connection(inp, [out])
-
-        out = inp[None, :, None, :] + inp[:, None, :, None]
-        with pytest.raises(
-            ValueError, match="Same known dimension used in different axis after broadcasting"
-        ):
-            subgraph_batch_dim_connection(inp, [out])
-
-    def test_blockwise(self):
-        inp = pt.tensor(shape=(5, 4))
-
-        invalid_out = inp @ pt.ones((4, 3))
-        with pytest.raises(ValueError, match="Use of known dimensions"):
-            subgraph_batch_dim_connection(inp, [invalid_out])
-
-        out = (inp[:, :, None, None] + pt.zeros((2, 3))) @ pt.ones((2, 3))
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1, None, None)
-
-    def test_random_variable(self):
-        inp = pt.tensor(shape=(5, 4, 3))
-
-        out1 = pt.random.normal(loc=inp)
-        out2 = pt.random.categorical(p=inp[..., None])
-        out3 = pt.random.multivariate_normal(mean=inp[..., None], cov=pt.eye(1))
-        [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3])
-        assert dims1 == (0, 1, 2)
-        assert dims2 == (0, 1, 2)
-        assert dims3 == (0, 1, 2, None)
-
-        invalid_out = pt.random.categorical(p=inp)
-        with pytest.raises(ValueError, match="Use of known dimensions"):
-            subgraph_batch_dim_connection(inp, [invalid_out])
-
-        invalid_out = pt.random.multivariate_normal(mean=inp, cov=pt.eye(3))
-        with pytest.raises(ValueError, match="Use of known dimensions"):
-            subgraph_batch_dim_connection(inp, [invalid_out])
-
-    def test_symbolic_random_variable(self):
-        inp = pt.tensor(shape=(4, 3, 2))
-
-        # Test univariate
-        out = CustomDist.dist(
-            inp,
-            dist=lambda mu, size: pt.random.normal(loc=mu, size=size),
-        )
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1, 2)
-
-        # Test multivariate
-        def dist(mu, size):
-            if isinstance(size.type, NoneTypeT):
-                size = mu.shape
-            return pt.random.normal(loc=mu[..., None], size=(*size, 2))
-
-        out = CustomDist.dist(inp, dist=dist, size=(4, 3, 2), signature="()->(2)")
-        [dims] = subgraph_batch_dim_connection(inp, [out])
-        assert dims == (0, 1, 2, None)