pymc_extras-0.2.0-py3-none-any.whl
This diff shows the contents of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in its public registry.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
tests/model/marginal/test_marginal_model.py
@@ -0,0 +1,867 @@
import itertools

from contextlib import suppress as does_not_warn

import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import pytest

from arviz import InferenceData, dict_to_dataset
from pymc.distributions import transforms
from pymc.distributions.transforms import ordered
from pymc.model.fgraph import fgraph_from_model
from pymc.pytensorf import inputvars
from pymc.util import UNSET
from scipy.special import log_softmax, logsumexp
from scipy.stats import halfnorm, norm

from pymc_extras.model.marginal.marginal_model import (
    MarginalModel,
    marginalize,
)
from tests.utils import equal_computations_up_to_root


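# These tests follow a common pattern: build a MarginalModel with one or more
# discrete random variables, call `marginalize`, and check the resulting logp
# (and sometimes dlogp or posterior samples) against a reference model in which
# the discrete variable is summed out analytically or by brute-force logsumexp.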
def test_basic_marginalized_rv():
    data = [2] * 5

    with MarginalModel() as m:
        sigma = pm.HalfNormal("sigma")
        idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6])
        mu = pt.switch(
            pt.eq(idx, 0),
            -1.0,
            pt.switch(
                pt.eq(idx, 1),
                0.0,
                1.0,
            ),
        )
        y = pm.Normal("y", mu=mu, sigma=sigma)
        z = pm.Normal("z", y, observed=data)

    m.marginalize([idx])
    assert idx not in m.free_RVs
    assert [rv.name for rv in m.marginalized_rvs] == ["idx"]

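    # A Categorical index over three Normal means, once marginalized, is exactly
    # a three-component NormalMixture, which provides the analytic reference.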
    # Test logp
    with pm.Model() as m_ref:
        sigma = pm.HalfNormal("sigma")
        y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma)
        z = pm.Normal("z", y, observed=data)

    test_point = m_ref.initial_point()
    ref_logp = m_ref.compile_logp()(test_point)
    ref_dlogp = m_ref.compile_dlogp([m_ref["y"]])(test_point)

    # Assert we can marginalize and unmarginalize internally non-destructively
    for _ in range(3):
        np.testing.assert_almost_equal(
            m.compile_logp()(test_point),
            ref_logp,
        )
        np.testing.assert_almost_equal(
            m.compile_dlogp([m["y"]])(test_point),
            ref_dlogp,
        )


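# Independent marginalized RVs should stay independent: x and y must not end up
# wrapped in a single joint marginalized node (checked via their owner ops).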
def test_one_to_one_marginalized_rvs():
    """Test case with multiple, independent marginalized RVs."""
    with MarginalModel() as m:
        sigma = pm.HalfNormal("sigma")
        idx1 = pm.Bernoulli("idx1", p=0.75)
        x = pm.Normal("x", mu=idx1, sigma=sigma)
        idx2 = pm.Bernoulli("idx2", p=0.75, shape=(5,))
        y = pm.Normal("y", mu=(idx2 * 2 - 1), sigma=sigma, shape=(5,))

    m.marginalize([idx1, idx2])
    assert m["x"].owner is not m["y"].owner
    _m = m.clone()._marginalize()
    assert _m["x"].owner is not _m["y"].owner

    with pm.Model() as m_ref:
        sigma = pm.HalfNormal("sigma")
        x = pm.NormalMixture("x", w=[0.25, 0.75], mu=[0, 1], sigma=sigma)
        y = pm.NormalMixture("y", w=[0.25, 0.75], mu=[-1, 1], sigma=sigma, shape=(5,))

    # Test logp
    test_point = m_ref.initial_point()
    x_logp, y_logp = m.compile_logp(vars=[m["x"], m["y"]], sum=False)(test_point)
    x_ref_logp, y_ref_logp = m_ref.compile_logp(vars=[m_ref["x"], m_ref["y"]], sum=False)(test_point)
    np.testing.assert_array_almost_equal(x_logp, x_ref_logp.sum())
    np.testing.assert_array_almost_equal(y_logp, y_ref_logp)


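# When several RVs depend on the same marginalized variable they must be
# marginalized jointly, which is what the "multiple dependent variables"
# warning signals.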
def test_one_to_many_marginalized_rvs():
    """Test that marginalization works when there is more than one dependent RV."""
    with MarginalModel() as m:
        sigma = pm.HalfNormal("sigma")
        idx = pm.Bernoulli("idx", p=0.75)
        x = pm.Normal("x", mu=idx, sigma=sigma)
        y = pm.Normal("y", mu=(idx * 2 - 1), sigma=sigma, shape=(5,))

    ref_logp_x_y_fn = m.compile_logp([idx, x, y])

    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
        m.marginalize([idx])

    assert m["x"].owner is not m["y"].owner
    _m = m.clone()._marginalize()
    assert _m["x"].owner is _m["y"].owner

    tp = m.initial_point()
    ref_logp_x_y = logsumexp([ref_logp_x_y_fn({**tp, **{"idx": idx}}) for idx in (0, 1)])
    logp_x_y = m.compile_logp([x, y])(tp)
    np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y)


def test_one_to_many_unaligned_marginalized_rvs():
    """Test that marginalization works when there is more than one dependent RV
    with batch dimensions that are not aligned."""

    def build_model(build_batched: bool):
        with MarginalModel() as m:
            if build_batched:
                idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2))
            else:
                idxs = [pm.Bernoulli(f"idx_{i}", p=(0.75 if i % 2 == 0 else 0.4)) for i in range(6)]
                idx = pt.stack(idxs, axis=0).reshape((3, 2))

            x = pm.Normal("x", mu=idx.T[:, :, None], shape=(2, 3, 1))
            y = pm.Normal("y", mu=(idx * 2 - 1), shape=(1, 3, 2))

        return m

    m = build_model(build_batched=True)
    ref_m = build_model(build_batched=False)

    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
        m.marginalize(["idx"])
        ref_m.marginalize([f"idx_{i}" for i in range(6)])

    test_point = m.initial_point()
    np.testing.assert_allclose(
        m.compile_logp()(test_point),
        ref_m.compile_logp()(test_point),
    )


def test_many_to_one_marginalized_rvs():
    """Test when random variables depend on multiple marginalized variables"""
    with MarginalModel() as m:
        x = pm.Bernoulli("x", 0.1)
        y = pm.Bernoulli("y", 0.3)
        z = pm.DiracDelta("z", c=x + y)

    m.marginalize([x, y])
    logp = m.compile_logp()

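    # z = x + y, so its pmf is the convolution of Bernoulli(0.1) and
    # Bernoulli(0.3), e.g. P(z=1) = 0.9 * 0.3 + 0.1 * 0.7.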
    np.testing.assert_allclose(np.exp(logp({"z": 0})), 0.9 * 0.7)
    np.testing.assert_allclose(np.exp(logp({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7)
    np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3)


@pytest.mark.parametrize("batched", (False, "left", "right"))
def test_nested_marginalized_rvs(batched):
    """Test that marginalization works when there are nested marginalized RVs"""

    def build_model(build_batched: bool) -> MarginalModel:
        idx_shape = (3,) if build_batched else ()
        sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5)

        with MarginalModel() as m:
            sigma = pm.HalfNormal("sigma")

            idx = pm.Bernoulli("idx", p=0.75, shape=idx_shape)
            dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma)

            sub_idx_p = pt.switch(pt.eq(idx, 0), 0.15, 0.95)
            if build_batched and batched == "right":
                sub_idx_p = sub_idx_p[..., None]
                dep = dep[..., None]
            sub_idx = pm.Bernoulli("sub_idx", p=sub_idx_p, shape=sub_idx_shape)
            sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma)

        return m

    m = build_model(build_batched=batched)
    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
        m.marginalize(["idx", "sub_idx"])
    assert sorted(rv.name for rv in m.marginalized_rvs) == ["idx", "sub_idx"]

    # Test logp
    ref_m = build_model(build_batched=False)
    ref_logp_fn = ref_m.compile_logp(
        vars=[ref_m["idx"], ref_m["dep"], ref_m["sub_idx"], ref_m["sub_dep"]]
    )

    test_point = ref_m.initial_point()
    test_point["dep"] = np.full_like(test_point["dep"], 1000)
    test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
    ref_logp = logsumexp(
        [
            ref_logp_fn({**test_point, **{"idx": idx, "sub_idx": np.array(sub_idxs)}})
            for idx in (0, 1)
            for sub_idxs in itertools.product((0, 1), repeat=5)
        ]
    )
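    # The batched model is three iid copies of the unbatched reference, so its
    # marginal logp is exactly 3 times the reference value.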
    if batched:
        ref_logp *= 3

    test_point = m.initial_point()
    test_point["dep"] = np.full_like(test_point["dep"], 1000)
    test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
    logp = m.compile_logp(vars=[m["dep"], m["sub_dep"]])(test_point)

    np.testing.assert_almost_equal(logp, ref_logp)


@pytest.mark.parametrize("advanced_indexing", (False, True))
def test_marginalized_index_as_key(advanced_indexing):
    """Test we can marginalize graphs where indexing is used as a mapping."""

    w = [0.1, 0.3, 0.6]
    mu = pt.as_tensor([-1, 0, 1])

    if advanced_indexing:
        y_val = pt.as_tensor([[-1, -1], [0, 1]])
        shape = (2, 2)
    else:
        y_val = -1
        shape = ()

    with MarginalModel() as m:
        x = pm.Categorical("x", p=w, shape=shape)
        y = pm.Normal("y", mu[x].T, sigma=1, observed=y_val)

    m.marginalize(x)

    marginal_logp = m.compile_logp(sum=False)({})[0]
    ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu.T, sigma=1, shape=shape), y_val).eval()

    np.testing.assert_allclose(marginal_logp, ref_logp)


def test_marginalized_index_as_value_and_key():
    """Test we can marginalize graphs where the marginalized_rv is indexed."""

    def build_model(build_batched: bool) -> MarginalModel:
        with MarginalModel() as m:
            if build_batched:
                latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,))
            else:
                latent_state = pm.math.stack(
                    [pm.Bernoulli(f"latent_state_{i}", p=0.3) for i in range(4)]
                )
            # latent state is used as the indexed variable
            latent_intensities = pt.where(latent_state[:, None], [0.0, 1.0, 2.0], [0.0, 10.0, 20.0])
            picked_intensity = pm.Categorical("picked_intensity", p=[0.2, 0.2, 0.6])
            # picked intensity is used as the indexing variable
            pm.Normal(
                "intensity",
                mu=latent_intensities[:, picked_intensity],
                observed=[0.5, 1.5, 5.0, 15.0],
            )
        return m

    # We compare against the equivalent but less efficient unbatched model
    m = build_model(build_batched=True)
    ref_m = build_model(build_batched=False)

    m.marginalize(["latent_state"])
    ref_m.marginalize([f"latent_state_{i}" for i in range(4)])
    test_point = {"picked_intensity": 1}
    np.testing.assert_allclose(
        m.compile_logp()(test_point),
        ref_m.compile_logp()(test_point),
    )

    m.marginalize(["picked_intensity"])
    ref_m.marginalize(["picked_intensity"])
    test_point = {}
    np.testing.assert_allclose(
        m.compile_logp()(test_point),
        ref_m.compile_logp()(test_point),
    )


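# The cases below mix batch dimensions of the marginalized RV (via dot
# products, cross-indexing, support dimensions, or outer sums), which
# `marginalize` does not support and should reject with NotImplementedError.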
class TestNotSupportedMixedDims:
    """Test lack of support for models where batch dims of marginalized variables are mixed."""

    def test_mixed_dims_via_transposed_dot(self):
        with MarginalModel() as m:
            idx = pm.Bernoulli("idx", p=0.7, shape=2)
            y = pm.Normal("y", mu=idx @ idx.T)
        with pytest.raises(NotImplementedError):
            m.marginalize(idx)

    def test_mixed_dims_via_indexing(self):
        mean = pt.as_tensor([[0.1, 0.9], [0.6, 0.4]])

        with MarginalModel() as m:
            idx = pm.Bernoulli("idx", p=0.7, shape=2)
            y = pm.Normal("y", mu=mean[idx, :] + mean[:, idx])
        with pytest.raises(NotImplementedError):
            m.marginalize(idx)

        with MarginalModel() as m:
            idx = pm.Bernoulli("idx", p=0.7, shape=2)
            y = pm.Normal("y", mu=mean[idx, None] + mean[None, idx])
        with pytest.raises(NotImplementedError):
            m.marginalize(idx)

        with MarginalModel() as m:
            idx = pm.Bernoulli("idx", p=0.7, shape=2)
            mu = pt.specify_broadcastable(mean[:, None][idx], 1) + pt.specify_broadcastable(
                mean[None, :][:, idx], 0
            )
            y = pm.Normal("y", mu=mu)
        with pytest.raises(NotImplementedError):
            m.marginalize(idx)

        with MarginalModel() as m:
            idx = pm.Bernoulli("idx", p=0.7, shape=2)
            y = pm.Normal("y", mu=idx[0] + idx[1])
        with pytest.raises(NotImplementedError):
            m.marginalize(idx)

    def test_mixed_dims_via_vector_indexing(self):
        with MarginalModel() as m:
            idx = pm.Bernoulli("idx", p=0.7, shape=2)
            y = pm.Normal("y", mu=idx[[0, 1, 0, 0]])
        with pytest.raises(NotImplementedError):
            m.marginalize(idx)

        with MarginalModel() as m:
            idx = pm.Categorical("key", p=[0.1, 0.3, 0.6], shape=(2, 2))
            y = pm.Normal("y", pt.as_tensor([[0, 1], [2, 3]])[idx.astype(bool)])
        with pytest.raises(NotImplementedError):
            m.marginalize(idx)

    def test_mixed_dims_via_support_dimension(self):
        with MarginalModel() as m:
            x = pm.Bernoulli("x", p=0.7, shape=3)
            y = pm.Dirichlet("y", a=x * 10 + 1)
        with pytest.raises(NotImplementedError):
            m.marginalize(x)

    def test_mixed_dims_via_nested_marginalization(self):
        with MarginalModel() as m:
            x = pm.Bernoulli("x", p=0.7, shape=(3,))
            y = pm.Bernoulli("y", p=0.7, shape=(2,))
            z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2))

        with pytest.raises(NotImplementedError):
            m.marginalize([x, y])


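# Deterministics and Potentials are allowed as long as they depend on the
# marginalized RV only through its dependent RVs; direct dependencies are
# rejected in the test after this one.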
def test_marginalized_deterministic_and_potential():
    rng = np.random.default_rng(299)

    with MarginalModel() as m:
        x = pm.Bernoulli("x", p=0.7)
        y = pm.Normal("y", x)
        z = pm.Normal("z", x)
        det = pm.Deterministic("det", y + z)
        pot = pm.Potential("pot", y + z + 1)

    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
        m.marginalize([x])

    y_draw, z_draw, det_draw, pot_draw = pm.draw([y, z, det, pot], draws=5, random_seed=rng)
    np.testing.assert_almost_equal(y_draw + z_draw, det_draw)
    np.testing.assert_almost_equal(det_draw, pot_draw - 1)

    y_value = m.rvs_to_values[y]
    z_value = m.rvs_to_values[z]
    det_value, pot_value = m.replace_rvs_by_values([det, pot])
    assert set(inputvars([det_value, pot_value])) == {y_value, z_value}
    assert det_value.eval({y_value: 2, z_value: 5}) == 7
    assert pot_value.eval({y_value: 2, z_value: 5}) == 8


def test_not_supported_marginalized_deterministic_and_potential():
    with MarginalModel() as m:
        x = pm.Bernoulli("x", p=0.7)
        y = pm.Normal("y", x)
        det = pm.Deterministic("det", x + y)

    with pytest.raises(
        NotImplementedError, match="Cannot marginalize x due to dependent Deterministic det"
    ):
        m.marginalize([x])

    with MarginalModel() as m:
        x = pm.Bernoulli("x", p=0.7)
        y = pm.Normal("y", x)
        pot = pm.Potential("pot", x + y)

    with pytest.raises(
        NotImplementedError, match="Cannot marginalize x due to dependent Potential pot"
    ):
        m.marginalize([x])


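# Default transforms should survive marginalization unchanged; for the Interval
# cases the model warns that a transform depending on the marginalized idx may
# no longer work.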
@pytest.mark.parametrize(
    "transform, expected_warning",
    (
        (None, does_not_warn()),
        (UNSET, does_not_warn()),
        (transforms.log, does_not_warn()),
        (transforms.Chain([transforms.log, transforms.logodds]), does_not_warn()),
        (
            transforms.Interval(0, 1),
            pytest.warns(
                UserWarning, match="which depends on the marginalized idx may no longer work"
            ),
        ),
        (
            transforms.Chain([transforms.log, transforms.Interval(0, 1)]),
            pytest.warns(
                UserWarning, match="which depends on the marginalized idx may no longer work"
            ),
        ),
    ),
)
def test_marginalized_transforms(transform, expected_warning):
    w = [0.1, 0.3, 0.6]
    data = [0, 5, 10]
    initval = 0.5  # Value that will be negative on the unconstrained space

    with pm.Model() as m_ref:
        sigma = pm.Mixture(
            "sigma",
            w=w,
            comp_dists=pm.HalfNormal.dist([1, 2, 3]),
            initval=initval,
            default_transform=transform,
        )
        y = pm.Normal("y", 0, sigma, observed=data)

    with MarginalModel() as m:
        idx = pm.Categorical("idx", p=w)
        sigma = pm.HalfNormal(
            "sigma",
            pt.switch(
                pt.eq(idx, 0),
                1,
                pt.switch(
                    pt.eq(idx, 1),
                    2,
                    3,
                ),
            ),
            initval=initval,
            default_transform=transform,
        )
        y = pm.Normal("y", 0, sigma, observed=data)

    with expected_warning:
        m.marginalize([idx])

    ip = m.initial_point()
    if transform is not None:
        if transform is UNSET:
            transform_name = "log"
        else:
            transform_name = transform.name
        assert f"sigma_{transform_name}__" in ip
    np.testing.assert_allclose(m.compile_logp()(ip), m_ref.compile_logp()(ip))


def test_data_container():
    """Test that MarginalModel can handle Data containers."""
    with MarginalModel(coords={"obs": [0]}) as marginal_m:
        x = pm.Data("x", 2.5)
        idx = pm.Bernoulli("idx", p=0.7, dims="obs")
        y = pm.Normal("y", idx * x, dims="obs")

    marginal_m.marginalize([idx])

    logp_fn = marginal_m.compile_logp()

    with pm.Model(coords={"obs": [0]}) as m_ref:
        x = pm.Data("x", 2.5)
        y = pm.NormalMixture("y", w=[0.3, 0.7], mu=[0, x], dims="obs")

    ref_logp_fn = m_ref.compile_logp()

    for i, x_val in enumerate((-1.5, 2.5, 3.5), start=1):
        for m in (marginal_m, m_ref):
            m.set_dim("obs", new_length=i, coord_values=tuple(range(i)))
            pm.set_data({"x": x_val}, model=m)

        ip = marginal_m.initial_point()
        np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ip))


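# Indexing with mutable pm.Data plus a marginalized RV must still compile under
# the JAX backend.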
def test_mutable_indexing_jax_backend():
    pytest.importorskip("jax")
    from pymc.sampling.jax import get_jaxified_logp

    with MarginalModel() as model:
        data = pm.Data("data", np.zeros(10))

        cat_effect = pm.Normal("cat_effect", sigma=1, shape=5)
        cat_effect_idx = pm.Data("cat_effect_idx", np.array([0, 1] * 5))

        is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10)
        pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
    model.marginalize(["is_outlier"])
    get_jaxified_logp(model)


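# `marginalize` is the functional counterpart of MarginalModel.marginalize: it
# should return a MarginalModel whose graph matches one built directly.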
def test_marginal_model_func():
    def create_model(model_class):
        with model_class(coords={"trial": range(10)}) as m:
            idx = pm.Bernoulli("idx", p=0.5, dims="trial")
            mu = pt.where(idx, 1, -1)
            sigma = pm.HalfNormal("sigma")
            y = pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1] * 10)
        return m

    marginal_m = marginalize(create_model(pm.Model), ["idx"])
    assert isinstance(marginal_m, MarginalModel)

    reference_m = create_model(MarginalModel)
    reference_m.marginalize(["idx"])

    # Check forward graph representation is the same
    marginal_fgraph, _ = fgraph_from_model(marginal_m)
    reference_fgraph, _ = fgraph_from_model(reference_m)
    assert equal_computations_up_to_root(marginal_fgraph.outputs, reference_fgraph.outputs)

    # Check logp graph is the same
    # This fails because OpFromGraphs comparison is broken
    # assert equal_computations_up_to_root([marginal_m.logp()], [reference_m.logp()])
    ip = marginal_m.initial_point()
    np.testing.assert_allclose(
        marginal_m.compile_logp()(ip),
        reference_m.compile_logp()(ip),
    )


class TestFullModels:
    @pytest.fixture
    def disaster_model(self):
        # fmt: off
        disaster_data = pd.Series(
            [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
             3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
             2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,
             1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
             0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
             3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
             0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]
        )
        # fmt: on
        years = np.arange(1851, 1962)

        with MarginalModel() as disaster_model:
            switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max())
            early_rate = pm.Exponential("early_rate", 1.0, initval=3)
            late_rate = pm.Exponential("late_rate", 1.0, initval=1)
            rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
            with pytest.warns(Warning):
                disasters = pm.Poisson("disasters", rate, observed=disaster_data)

        return disaster_model, years

    def test_change_point_model(self, disaster_model):
        m, years = disaster_model

        ip = m.initial_point()
        ip.pop("switchpoint")
        ref_logp_fn = m.compile_logp(
            [m["switchpoint"], m["disasters_observed"], m["disasters_unobserved"]]
        )
        ref_logp = logsumexp([ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years])

        with pytest.warns(UserWarning, match="There are multiple dependent variables"):
            m.marginalize(m["switchpoint"])

        logp = m.compile_logp([m["disasters_observed"], m["disasters_unobserved"]])(ip)
        np.testing.assert_almost_equal(logp, ref_logp)

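    # Sampling the model before and after marginalizing the switchpoint should
    # recover the same posterior moments (up to MCMC error).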
    @pytest.mark.slow
    def test_change_point_model_sampling(self, disaster_model):
        m, _ = disaster_model

        rng = np.random.default_rng(211)

        with m:
            before_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(
                sample=("draw", "chain")
            )

        with pytest.warns(UserWarning, match="There are multiple dependent variables"):
            m.marginalize([m["switchpoint"]])

        with m:
            after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(
                sample=("draw", "chain")
            )

        np.testing.assert_allclose(
            before_marg["early_rate"].mean(), after_marg["early_rate"].mean(), rtol=1e-2
        )
        np.testing.assert_allclose(
            before_marg["late_rate"].mean(), after_marg["late_rate"].mean(), rtol=1e-2
        )
        np.testing.assert_allclose(
            before_marg["disasters_unobserved"].mean(),
            after_marg["disasters_unobserved"].mean(),
            rtol=1e-2,
        )

    @pytest.mark.parametrize("univariate", (True, False))
    def test_vector_univariate_mixture(self, univariate):
        with MarginalModel() as m:
            idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())

            def dist(idx, size):
                return pm.math.switch(
                    pm.math.eq(idx, 0),
                    pm.Normal.dist([-10, -10], 1),
                    pm.Normal.dist([10, 10], 1),
                )

            pm.CustomDist("norm", idx, dist=dist)

        m.marginalize(idx)
        logp_fn = m.compile_logp()

        if univariate:
            with pm.Model() as ref_m:
                pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
        else:
            with pm.Model() as ref_m:
                pm.Mixture(
                    "norm",
                    w=[0.5, 0.5],
                    comp_dists=[
                        pm.MvNormal.dist([-10, -10], np.eye(2)),
                        pm.MvNormal.dist([10, 10], np.eye(2)),
                    ],
                    shape=(2,),
                )
        ref_logp_fn = ref_m.compile_logp()

        for test_value in (
            [-10, -10],
            [10, 10],
            [-10, 10],
            [10, -10],
        ):
            test_point = {"norm": test_value}
            np.testing.assert_allclose(logp_fn(test_point), ref_logp_fn(test_point))

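    # Censoring applies elementwise, so marginalizing the batched cluster index
    # must match a reference model built from per-observation scalar indices.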
    def test_k_censored_clusters_model(self):
        def build_model(build_batched: bool) -> MarginalModel:
            data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
            nobs = data.shape[0]
            n_clusters = 5
            coords = {
                "cluster": range(n_clusters),
                "ndim": ("x", "y"),
                "obs": range(nobs),
            }
            with MarginalModel(coords=coords) as m:
                if build_batched:
                    idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"])
                else:
                    idx = pm.math.stack(
                        [
                            pm.Categorical(f"idx_{i}", p=np.ones(n_clusters) / n_clusters)
                            for i in range(nobs)
                        ]
                    )

                mu_x = pm.Normal(
                    "mu_x",
                    dims=["cluster"],
                    transform=ordered,
                    initval=np.linspace(-1, 1, n_clusters),
                )
                mu_y = pm.Normal("mu_y", dims=["cluster"])
                mu = pm.math.stack([mu_x, mu_y], axis=-1)  # (cluster, ndim)
                mu_indexed = mu[idx, :]

                sigma = pm.HalfNormal("sigma")

                y = pm.Censored(
                    "y",
                    dist=pm.Normal.dist(mu_indexed, sigma),
                    lower=-3,
                    upper=3,
                    observed=data,
                    dims=["obs", "ndim"],
                )

            return m

        m = build_model(build_batched=True)
        ref_m = build_model(build_batched=False)

        m.marginalize([m["idx"]])
        ref_m.marginalize([n for n in ref_m.named_vars if n.startswith("idx_")])

        test_point = m.initial_point()
        np.testing.assert_almost_equal(
            m.compile_logp()(test_point),
            ref_m.compile_logp()(test_point),
        )


class TestRecoverMarginals:
    def test_basic(self):
        with MarginalModel() as m:
            sigma = pm.HalfNormal("sigma")
            p = np.array([0.5, 0.2, 0.3])
            k = pm.Categorical("k", p=p)
            mu = np.array([-3.0, 0.0, 3.0])
            mu_ = pt.as_tensor_variable(mu)
            y = pm.Normal("y", mu=mu_[k], sigma=sigma)

        m.marginalize([k])

        rng = np.random.default_rng(211)

        with m:
            prior = pm.sample_prior_predictive(
                draws=20,
                random_seed=rng,
                return_inferencedata=False,
            )
            idata = InferenceData(posterior=dict_to_dataset(prior))

        idata = m.recover_marginals(idata, return_samples=True)
        post = idata.posterior
        assert "k" in post
        assert "lp_k" in post
        assert post.k.shape == post.y.shape
        assert post.lp_k.shape == (*post.k.shape, len(p))

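        # Reference: p(k | y, sigma) via Bayes rule. log_softmax over the
        # category axis normalizes the joint logp; the halfnorm prior and the
        # log-transform Jacobian term np.log(sigma) are constant across k and
        # cancel, but are kept for symmetry with the model's joint logp.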
        def true_logp(y, sigma):
            y = y.repeat(len(p)).reshape(len(y), -1)
            sigma = sigma.repeat(len(p)).reshape(len(sigma), -1)
            return log_softmax(
                np.log(p)
                + norm.logpdf(y, loc=mu, scale=sigma)
                + halfnorm.logpdf(sigma)
                + np.log(sigma),
                axis=1,
            )

        np.testing.assert_almost_equal(
            true_logp(post.y.values.flatten(), post.sigma.values.flatten()),
            post.lp_k[0].values,
        )
        np.testing.assert_almost_equal(logsumexp(post.lp_k, axis=-1), 0)

    def test_coords(self):
        """Test that coords are recovered if the marginalized value originally had them"""
        with MarginalModel(coords={"year": [1990, 1991, 1992]}) as m:
            sigma = pm.HalfNormal("sigma")
            idx = pm.Bernoulli("idx", p=0.75, dims="year")
            x = pm.Normal("x", mu=idx, sigma=sigma, dims="year")

        m.marginalize([idx])
        rng = np.random.default_rng(211)

        with m:
            prior = pm.sample_prior_predictive(
                draws=20,
                random_seed=rng,
                return_inferencedata=False,
            )
            idata = InferenceData(
                posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
            )

        idata = m.recover_marginals(idata, return_samples=True)
        post = idata.posterior
        assert post.idx.dims == ("chain", "draw", "year")
        assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim")

    def test_batched(self):
        """Test that recover_marginals works for batched random variables"""
        with MarginalModel() as m:
            sigma = pm.HalfNormal("sigma")
            idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
            y = pm.Normal("y", mu=idx.T, sigma=sigma, shape=(2, 3))

        m.marginalize([idx])

        rng = np.random.default_rng(211)

        with m:
            prior = pm.sample_prior_predictive(
                draws=20,
                random_seed=rng,
                return_inferencedata=False,
            )
            idata = InferenceData(
                posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
            )

        idata = m.recover_marginals(idata, return_samples=True)
        post = idata.posterior
        assert post["y"].shape == (1, 20, 2, 3)
        assert post["idx"].shape == (1, 20, 3, 2)
        assert post["lp_idx"].shape == (1, 20, 3, 2, 2)

    def test_nested(self):
        """Test that marginals can be recovered when there are nested marginalized RVs"""

        with MarginalModel() as m:
            idx = pm.Bernoulli("idx", p=0.75)
            sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95))
            sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0)

        m.marginalize([idx, sub_idx])

        rng = np.random.default_rng(211)

        with m:
            prior = pm.sample_prior_predictive(
                draws=20,
                random_seed=rng,
                return_inferencedata=False,
            )
            idata = InferenceData(posterior=dict_to_dataset(prior))

        idata = m.recover_marginals(idata, return_samples=True)
        post = idata.posterior
        assert "idx" in post
        assert "lp_idx" in post
        assert post.idx.shape == post.y.shape
        assert post.lp_idx.shape == (*post.idx.shape, 2)
        assert "sub_idx" in post
        assert "lp_sub_idx" in post
        assert post.sub_idx.shape == post.y.shape
        assert post.lp_sub_idx.shape == (*post.sub_idx.shape, 2)

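        # Reference posteriors by direct enumeration: P(idx=1) = 0.75,
        # P(sub_idx=1 | idx) is 0.15 or 0.95, and y ~ Normal(idx + sub_idx, 1),
        # so each branch weight below is P(idx) * P(sub_idx | idx).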
        def true_idx_logp(y):
            idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1))
            idx_1 = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
            return log_softmax(np.stack([idx_0, idx_1]).T, axis=1)

        np.testing.assert_almost_equal(
            true_idx_logp(post.y.values.flatten()),
            post.lp_idx[0].values,
        )

        def true_sub_idx_logp(y):
            sub_idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1))
            sub_idx_1 = np.log(0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
            return log_softmax(np.stack([sub_idx_0, sub_idx_1]).T, axis=1)

        np.testing.assert_almost_equal(
            true_sub_idx_logp(post.y.values.flatten()),
            post.lp_sub_idx[0].values,
        )
        np.testing.assert_almost_equal(logsumexp(post.lp_idx, axis=-1), 0)
        np.testing.assert_almost_equal(logsumexp(post.lp_sub_idx, axis=-1), 0)