pymc-extras 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- pymc_extras/__init__.py +5 -1
- pymc_extras/distributions/continuous.py +3 -2
- pymc_extras/distributions/discrete.py +3 -1
- pymc_extras/inference/find_map.py +62 -17
- pymc_extras/inference/laplace.py +10 -7
- pymc_extras/statespace/core/statespace.py +191 -52
- pymc_extras/statespace/filters/distributions.py +15 -16
- pymc_extras/statespace/filters/kalman_filter.py +1 -18
- pymc_extras/statespace/filters/kalman_smoother.py +2 -6
- pymc_extras/statespace/models/ETS.py +10 -0
- pymc_extras/statespace/models/SARIMAX.py +26 -5
- pymc_extras/statespace/models/VARMAX.py +12 -2
- pymc_extras/statespace/models/structural.py +18 -5
- pymc_extras-0.2.6.dist-info/METADATA +318 -0
- pymc_extras-0.2.6.dist-info/RECORD +65 -0
- {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
- pymc_extras/version.py +0 -11
- pymc_extras/version.txt +0 -1
- pymc_extras-0.2.5.dist-info/METADATA +0 -112
- pymc_extras-0.2.5.dist-info/RECORD +0 -108
- pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -13
- tests/distributions/__init__.py +0 -19
- tests/distributions/test_continuous.py +0 -185
- tests/distributions/test_discrete.py +0 -210
- tests/distributions/test_discrete_markov_chain.py +0 -258
- tests/distributions/test_multivariate.py +0 -304
- tests/distributions/test_transform.py +0 -77
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +0 -132
- tests/model/marginal/test_graph_analysis.py +0 -182
- tests/model/marginal/test_marginal_model.py +0 -967
- tests/model/test_model_api.py +0 -38
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +0 -411
- tests/statespace/test_SARIMAX.py +0 -405
- tests/statespace/test_VARMAX.py +0 -184
- tests/statespace/test_coord_assignment.py +0 -181
- tests/statespace/test_distributions.py +0 -270
- tests/statespace/test_kalman_filter.py +0 -326
- tests/statespace/test_representation.py +0 -175
- tests/statespace/test_statespace.py +0 -872
- tests/statespace/test_statespace_JAX.py +0 -156
- tests/statespace/test_structural.py +0 -836
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +0 -9
- tests/statespace/utilities/statsmodel_local_level.py +0 -42
- tests/statespace/utilities/test_helpers.py +0 -310
- tests/test_blackjax_smc.py +0 -222
- tests/test_find_map.py +0 -103
- tests/test_histogram_approximation.py +0 -109
- tests/test_laplace.py +0 -281
- tests/test_linearmodel.py +0 -208
- tests/test_model_builder.py +0 -306
- tests/test_pathfinder.py +0 -297
- tests/test_pivoted_cholesky.py +0 -24
- tests/test_printing.py +0 -98
- tests/test_prior_from_trace.py +0 -172
- tests/test_splines.py +0 -77
- tests/utils.py +0 -0
- {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.6.dist-info}/licenses/LICENSE +0 -0
tests/model/marginal/test_marginal_model.py
@@ -1,967 +0,0 @@
-import itertools
-
-from contextlib import suppress as does_not_warn
-
-import numpy as np
-import pandas as pd
-import pymc as pm
-import pytensor.tensor as pt
-import pytest
-
-from arviz import InferenceData, dict_to_dataset
-from pymc import Model, draw
-from pymc.distributions import transforms
-from pymc.distributions.transforms import ordered
-from pymc.initial_point import make_initial_point_expression
-from pymc.pytensorf import constant_fold, inputvars
-from pymc.util import UNSET
-from scipy.special import log_softmax, logsumexp
-from scipy.stats import halfnorm, norm
-
-from pymc_extras.model.marginal.distributions import MarginalRV
-from pymc_extras.model.marginal.marginal_model import (
-    marginalize,
-    recover_marginals,
-    unmarginalize,
-)
-from pymc_extras.utils.model_equivalence import equivalent_models
-
-
-def test_basic_marginalized_rv():
-    data = [2] * 5
-
-    with Model() as m:
-        sigma = pm.HalfNormal("sigma")
-        idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6])
-        mu = pt.switch(
-            pt.eq(idx, 0),
-            -1.0,
-            pt.switch(
-                pt.eq(idx, 1),
-                0.0,
-                1.0,
-            ),
-        )
-        y = pm.Normal("y", mu=mu, sigma=sigma)
-        z = pm.Normal("z", y, observed=data)
-
-    marginal_m = marginalize(m, [idx])
-    assert isinstance(marginal_m["y"].owner.op, MarginalRV)
-    assert ["idx"] not in [rv.name for rv in marginal_m.free_RVs]
-
-    # Test forward draws
-    y_draws, z_draws = draw(
-        [marginal_m["y"], marginal_m["z"]],
-        # Make sigma very small to make draws deterministic
-        givens={marginal_m["sigma"]: 0.001},
-        draws=1000,
-        random_seed=54,
-    )
-    assert sorted(np.unique(y_draws.round()) == [-1.0, 0.0, 1.0])
-    assert z_draws[y_draws < 0].mean() < z_draws[y_draws > 0].mean()
-
-    # Test initial_point
-    ips = make_initial_point_expression(
-        # Use basic_RVs to include the observed RV
-        free_rvs=marginal_m.basic_RVs,
-        rvs_to_transforms=marginal_m.rvs_to_transforms,
-        initval_strategies={},
-    )
-    # After simplification, we should have only constants in the graph (expect alloc which isn't constant folded):
-    ip_sigma, ip_y, ip_z = constant_fold(ips)
-    np.testing.assert_allclose(ip_sigma, 1.0)
-    np.testing.assert_allclose(ip_y, 1.0)
-    np.testing.assert_allclose(ip_z, np.full((5,), 1.0))
-
-    marginal_ip = marginal_m.initial_point()
-    expected_ip = m.initial_point()
-    expected_ip.pop("idx")
-    assert marginal_ip == expected_ip
-
-    # Test logp
-    with pm.Model() as ref_m:
-        sigma = pm.HalfNormal("sigma")
-        y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma)
-        z = pm.Normal("z", y, observed=data)
-
-    np.testing.assert_almost_equal(
-        marginal_m.compile_logp()(marginal_ip),
-        ref_m.compile_logp()(marginal_ip),
-    )
-    np.testing.assert_almost_equal(
-        marginal_m.compile_dlogp([marginal_m["y"]])(marginal_ip),
-        ref_m.compile_dlogp([ref_m["y"]])(marginal_ip),
-    )
-
-
-def test_one_to_one_marginalized_rvs():
-    """Test case with multiple, independent marginalized RVs."""
-    with Model() as m:
-        sigma = pm.HalfNormal("sigma")
-        idx1 = pm.Bernoulli("idx1", p=0.75)
-        x = pm.Normal("x", mu=idx1, sigma=sigma)
-        idx2 = pm.Bernoulli("idx2", p=0.75, shape=(5,))
-        y = pm.Normal("y", mu=(idx2 * 2 - 1), sigma=sigma, shape=(5,))
-
-    marginal_m = marginalize(m, [idx1, idx2])
-    assert isinstance(marginal_m["x"].owner.op, MarginalRV)
-    assert isinstance(marginal_m["y"].owner.op, MarginalRV)
-    assert marginal_m["x"].owner is not marginal_m["y"].owner
-
-    with pm.Model() as ref_m:
-        sigma = pm.HalfNormal("sigma")
-        x = pm.NormalMixture("x", w=[0.25, 0.75], mu=[0, 1], sigma=sigma)
-        y = pm.NormalMixture("y", w=[0.25, 0.75], mu=[-1, 1], sigma=sigma, shape=(5,))
-
-    # Test logp
-    test_point = ref_m.initial_point()
-    x_logp, y_logp = marginal_m.compile_logp(vars=[marginal_m["x"], marginal_m["y"]], sum=False)(
-        test_point
-    )
-    x_ref_log, y_ref_logp = ref_m.compile_logp(vars=[ref_m["x"], ref_m["y"]], sum=False)(test_point)
-    np.testing.assert_array_almost_equal(x_logp, x_ref_log.sum())
-    np.testing.assert_array_almost_equal(y_logp, y_ref_logp)
-
-
-def test_one_to_many_marginalized_rvs():
-    """Test that marginalization works when there is more than one dependent RV"""
-    with Model() as m:
-        sigma = pm.HalfNormal("sigma")
-        idx = pm.Bernoulli("idx", p=0.75)
-        x = pm.Normal("x", mu=idx, sigma=sigma)
-        y = pm.Normal("y", mu=(idx * 2 - 1), sigma=sigma, shape=(5,))
-
-    marginal_m = marginalize(m, [idx])
-
-    marginal_x = marginal_m["x"]
-    marginal_y = marginal_m["y"]
-    assert isinstance(marginal_x.owner.op, MarginalRV)
-    assert isinstance(marginal_y.owner.op, MarginalRV)
-    assert marginal_x.owner is marginal_y.owner
-
-    ref_logp_x_y_fn = m.compile_logp([idx, x, y])
-    tp = marginal_m.initial_point()
-    ref_logp_x_y = logsumexp([ref_logp_x_y_fn({**tp, **{"idx": idx}}) for idx in (0, 1)])
-    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-        logp_x_y = marginal_m.compile_logp([marginal_x, marginal_y])(tp)
-    np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y)
-
-
-def test_one_to_many_unaligned_marginalized_rvs():
-    """Test that marginalization works when there is more than one dependent RV with batch dimensions that are not aligned"""
-
-    def build_model(build_batched: bool):
-        with Model() as m:
-            if build_batched:
-                idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2))
-            else:
-                idxs = [pm.Bernoulli(f"idx_{i}", p=(0.75 if i % 2 == 0 else 0.4)) for i in range(6)]
-                idx = pt.stack(idxs, axis=0).reshape((3, 2))
-
-            x = pm.Normal("x", mu=idx.T[:, :, None], shape=(2, 3, 1))
-            y = pm.Normal("y", mu=(idx * 2 - 1), shape=(1, 3, 2))
-
-        return m
-
-    marginal_m = marginalize(build_model(build_batched=True), ["idx"])
-    ref_m = marginalize(build_model(build_batched=False), [f"idx_{i}" for i in range(6)])
-
-    test_point = marginal_m.initial_point()
-
-    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-        np.testing.assert_allclose(
-            marginal_m.compile_logp()(test_point),
-            ref_m.compile_logp()(test_point),
-        )
-
-
-def test_many_to_one_marginalized_rvs():
-    """Test when random variables depend on multiple marginalized variables"""
-    with Model() as m:
-        x = pm.Bernoulli("x", 0.1)
-        y = pm.Bernoulli("y", 0.3)
-        z = pm.DiracDelta("z", c=x + y)
-
-    logp_fn = marginalize(m, [x, y]).compile_logp()
-
-    np.testing.assert_allclose(np.exp(logp_fn({"z": 0})), 0.9 * 0.7)
-    np.testing.assert_allclose(np.exp(logp_fn({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7)
-    np.testing.assert_allclose(np.exp(logp_fn({"z": 2})), 0.1 * 0.3)
-
-
-@pytest.mark.parametrize("batched", (False, "left", "right"))
-def test_nested_marginalized_rvs(batched):
-    """Test that marginalization works when there are nested marginalized RVs"""
-
-    def build_model(build_batched: bool) -> Model:
-        idx_shape = (3,) if build_batched else ()
-        sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5)
-
-        with Model() as m:
-            sigma = pm.HalfNormal("sigma")
-
-            idx = pm.Bernoulli("idx", p=0.75, shape=idx_shape)
-            dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma)
-
-            sub_idx_p = pt.switch(pt.eq(idx, 0), 0.15, 0.95)
-            if build_batched and batched == "right":
-                sub_idx_p = sub_idx_p[..., None]
-                dep = dep[..., None]
-            sub_idx = pm.Bernoulli("sub_idx", p=sub_idx_p, shape=sub_idx_shape)
-            sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma)
-
-        return m
-
-    marginal_m = marginalize(build_model(build_batched=batched), ["idx", "sub_idx"])
-    assert all(rv.name not in ("idx", "sub_idx") for rv in marginal_m.free_RVs)
-
-    # Test forward draws and initial_point, shouldn't depend on batching, so we only test one case
-    if not batched:
-        # Test forward draws
-        dep_draws, sub_dep_draws = draw(
-            [marginal_m["dep"], marginal_m["sub_dep"]],
-            # Make sigma very small to make draws deterministic
-            givens={marginal_m["sigma"]: 0.001},
-            draws=1000,
-            random_seed=214,
-        )
-        assert sorted(np.unique(dep_draws.round()) == [-1000.0, 1000.0])
-        assert sorted(np.unique(sub_dep_draws.round()) == [-1000.0, -900.0, 1000.0, 1100.0])
-
-        # Test initial_point
-        ips = make_initial_point_expression(
-            free_rvs=marginal_m.free_RVs,
-            rvs_to_transforms=marginal_m.rvs_to_transforms,
-            initval_strategies={},
-        )
-        # After simplification, we should have only constants in the graph
-        ip_sigma, ip_dep, ip_sub_dep = constant_fold(ips)
-        np.testing.assert_allclose(ip_sigma, 1.0)
-        np.testing.assert_allclose(ip_dep, 1000.0)
-        np.testing.assert_allclose(ip_sub_dep, np.full((5,), 1100.0))
-
-    # Test logp
-    ref_m = build_model(build_batched=False)
-    ref_logp_fn = ref_m.compile_logp(
-        vars=[ref_m["idx"], ref_m["dep"], ref_m["sub_idx"], ref_m["sub_dep"]]
-    )
-
-    test_point = ref_m.initial_point()
-    test_point["dep"] = np.full_like(test_point["dep"], 1000)
-    test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
-    ref_logp = logsumexp(
-        [
-            ref_logp_fn({**test_point, **{"idx": idx, "sub_idx": np.array(sub_idxs)}})
-            for idx in (0, 1)
-            for sub_idxs in itertools.product((0, 1), repeat=5)
-        ]
-    )
-    if batched:
-        ref_logp *= 3
-
-    test_point = marginal_m.initial_point()
-    test_point["dep"] = np.full_like(test_point["dep"], 1000)
-    test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
-
-    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-        logp = marginal_m.compile_logp(vars=[marginal_m["dep"], marginal_m["sub_dep"]])(test_point)
-
-    np.testing.assert_almost_equal(logp, ref_logp)
-
-
-def test_interdependent_rvs():
-    """Test Marginalization when dependent RVs are interdependent."""
-    with Model() as m:
-        idx = pm.Bernoulli("idx", p=0.75)
-        x = pm.Normal("x", mu=idx * 2, sigma=1e-3)
-        # Y depends on both x and idx
-        y = pm.Normal("y", mu=x * idx * 2, sigma=1e-3)
-
-    marginal_m = marginalize(m, "idx")
-
-    marginal_x = marginal_m["x"]
-    marginal_y = marginal_m["y"]
-    assert isinstance(marginal_x.owner.op, MarginalRV)
-    assert isinstance(marginal_y.owner.op, MarginalRV)
-    assert marginal_x.owner is marginal_y.owner
-
-    # Test forward draws
-    x_draws, y_draws = draw([marginal_x, marginal_y], draws=1000, random_seed=54)
-    assert sorted(np.unique(x_draws.round())) == [0, 2]
-    assert sorted(np.unique(y_draws.round())) == [0, 4]
-    assert np.unique(y_draws[x_draws < 1].round()) == [0]
-    assert np.unique(y_draws[x_draws > 1].round()) == [4]
-
-    # Test initial_point
-    ips = make_initial_point_expression(
-        free_rvs=marginal_m.free_RVs,
-        rvs_to_transforms={},
-        initval_strategies={},
-    )
-    # After simplification, we should have only constants in the graph
-    ip_x, ip_y = constant_fold(ips)
-    np.testing.assert_allclose(ip_x, 2.0)
-    np.testing.assert_allclose(ip_y, 4.0)
-
-    # Test custom initval strategy
-    ips = make_initial_point_expression(
-        # Test that order does not matter
-        free_rvs=marginal_m.free_RVs[::-1],
-        rvs_to_transforms={},
-        initval_strategies={marginal_x: pt.constant(5.0)},
-    )
-    ip_y, ip_x = constant_fold(ips)
-    np.testing.assert_allclose(ip_x, 5.0)
-    np.testing.assert_allclose(ip_y, 10.0)
-
-    # Test logp
-    test_point = marginal_m.initial_point()
-    ref_logp_fn = m.compile_logp([m["idx"], m["x"], m["y"]])
-    ref_logp = logsumexp([ref_logp_fn({**test_point, **{"idx": idx}}) for idx in (0, 1)])
-    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-        logp = marginal_m.compile_logp([marginal_m["x"], marginal_m["y"]])(test_point)
-    np.testing.assert_almost_equal(logp, ref_logp)
-
-
-@pytest.mark.parametrize("advanced_indexing", (False, True))
-def test_marginalized_index_as_key(advanced_indexing):
-    """Test we can marginalize graphs where indexing is used as a mapping."""
-
-    w = [0.1, 0.3, 0.6]
-    mu = pt.as_tensor([-1, 0, 1])
-
-    if advanced_indexing:
-        y_val = pt.as_tensor([[-1, -1], [0, 1]])
-        shape = (2, 2)
-    else:
-        y_val = -1
-        shape = ()
-
-    with Model() as m:
-        x = pm.Categorical("x", p=w, shape=shape)
-        y = pm.Normal("y", mu[x].T, sigma=1, observed=y_val)
-
-    marginal_m = marginalize(m, x)
-
-    marginal_logp = marginal_m.compile_logp(sum=False)({})[0]
-    ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu.T, sigma=1, shape=shape), y_val).eval()
-
-    np.testing.assert_allclose(marginal_logp, ref_logp)
-
-
-def test_marginalized_index_as_value_and_key():
-    """Test we can marginalize graphs were marginalized_rv is indexed."""
-
-    def build_model(build_batched: bool) -> Model:
-        with Model() as m:
-            if build_batched:
-                latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,))
-            else:
-                latent_state = pm.math.stack(
-                    [pm.Bernoulli(f"latent_state_{i}", p=0.3) for i in range(4)]
-                )
-            # latent state is used as the indexed variable
-            latent_intensities = pt.where(latent_state[:, None], [0.0, 1.0, 2.0], [0.0, 10.0, 20.0])
-            picked_intensity = pm.Categorical("picked_intensity", p=[0.2, 0.2, 0.6])
-            # picked intensity is used as the indexing variable
-            pm.Normal(
-                "intensity",
-                mu=latent_intensities[:, picked_intensity],
-                observed=[0.5, 1.5, 5.0, 15.0],
-            )
-        return m
-
-    # We compare with the equivalent but less efficient batched model
-    m = build_model(build_batched=True)
-    ref_m = build_model(build_batched=False)
-
-    m = marginalize(m, ["latent_state"])
-    ref_m = marginalize(ref_m, [f"latent_state_{i}" for i in range(4)])
-    test_point = {"picked_intensity": 1}
-    np.testing.assert_allclose(
-        m.compile_logp()(test_point),
-        ref_m.compile_logp()(test_point),
-    )
-
-    m = marginalize(m, ["picked_intensity"])
-    ref_m = marginalize(ref_m, ["picked_intensity"])
-    test_point = {}
-    np.testing.assert_allclose(
-        m.compile_logp()(test_point),
-        ref_m.compile_logp()(test_point),
-    )
-
-
-class TestNotSupportedMixedDims:
-    """Test lack of support for models where batch dims of marginalized variables are mixed."""
-
-    def test_mixed_dims_via_transposed_dot(self):
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.7, shape=2)
-            y = pm.Normal("y", mu=idx @ idx.T)
-
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-    def test_mixed_dims_via_indexing(self):
-        mean = pt.as_tensor([[0.1, 0.9], [0.6, 0.4]])
-
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.7, shape=2)
-            y = pm.Normal("y", mu=mean[idx, :] + mean[:, idx])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.7, shape=2)
-            y = pm.Normal("y", mu=mean[idx, None] + mean[None, idx])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.7, shape=2)
-            mu = pt.specify_broadcastable(mean[:, None][idx], 1) + pt.specify_broadcastable(
-                mean[None, :][:, idx], 0
-            )
-            y = pm.Normal("y", mu=mu)
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.7, shape=2)
-            y = pm.Normal("y", mu=idx[0] + idx[1])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-    def test_mixed_dims_via_vector_indexing(self):
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.7, shape=2)
-            y = pm.Normal("y", mu=idx[[0, 1, 0, 0]])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-        with Model() as m:
-            idx = pm.Categorical("key", p=[0.1, 0.3, 0.6], shape=(2, 2))
-            y = pm.Normal("y", pt.as_tensor([[0, 1], [2, 3]])[idx.astype(bool)])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-    def test_mixed_dims_via_support_dimension(self):
-        with Model() as m:
-            x = pm.Bernoulli("x", p=0.7, shape=3)
-            y = pm.Dirichlet("y", a=x * 10 + 1)
-        with pytest.raises(NotImplementedError):
-            marginalize(m, x)
-
-    def test_mixed_dims_via_nested_marginalization(self):
-        with Model() as m:
-            x = pm.Bernoulli("x", p=0.7, shape=(3,))
-            y = pm.Bernoulli("y", p=0.7, shape=(2,))
-            z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2))
-
-        with pytest.raises(NotImplementedError):
-            marginalize(m, [x, y])
-
-
-def test_marginalized_deterministic_and_potential():
-    rng = np.random.default_rng(299)
-
-    with Model() as m:
-        x = pm.Bernoulli("x", p=0.7)
-        y = pm.Normal("y", x)
-        z = pm.Normal("z", x)
-        det = pm.Deterministic("det", y + z)
-        pot = pm.Potential("pot", y + z + 1)
-
-    marginal_m = marginalize(m, [x])
-
-    y_draw, z_draw, det_draw, pot_draw = pm.draw([y, z, det, pot], draws=5, random_seed=rng)
-    np.testing.assert_almost_equal(y_draw + z_draw, det_draw)
-    np.testing.assert_almost_equal(det_draw, pot_draw - 1)
-
-    y_value = marginal_m.rvs_to_values[marginal_m["y"]]
-    z_value = marginal_m.rvs_to_values[marginal_m["z"]]
-    det_value, pot_value = marginal_m.replace_rvs_by_values([marginal_m["det"], marginal_m["pot"]])
-    assert set(inputvars([det_value, pot_value])) == {y_value, z_value}
-    assert det_value.eval({y_value: 2, z_value: 5}) == 7
-    assert pot_value.eval({y_value: 2, z_value: 5}) == 8
-
-
-def test_not_supported_marginalized_deterministic_and_potential():
-    with Model() as m:
-        x = pm.Bernoulli("x", p=0.7)
-        y = pm.Normal("y", x)
-        det = pm.Deterministic("det", x + y)
-
-    with pytest.raises(
-        NotImplementedError, match="Cannot marginalize x due to dependent Deterministic det"
-    ):
-        marginalize(m, [x])
-
-    with Model() as m:
-        x = pm.Bernoulli("x", p=0.7)
-        y = pm.Normal("y", x)
-        pot = pm.Potential("pot", x + y)
-
-    with pytest.raises(
-        NotImplementedError, match="Cannot marginalize x due to dependent Potential pot"
-    ):
-        marginalize(m, [x])
-
-
-@pytest.mark.parametrize(
-    "transform, expected_warning",
-    (
-        (None, does_not_warn()),
-        (UNSET, does_not_warn()),
-        (transforms.log, does_not_warn()),
-        (transforms.Chain([transforms.logodds, transforms.log]), does_not_warn()),
-        (
-            transforms.Interval(0, 2),
-            pytest.warns(
-                UserWarning, match="which depends on the marginalized idx may no longer work"
-            ),
-        ),
-        (
-            transforms.Chain([transforms.log, transforms.Interval(-1, 1)]),
-            pytest.warns(
-                UserWarning, match="which depends on the marginalized idx may no longer work"
-            ),
-        ),
-    ),
-)
-def test_marginalized_transforms(transform, expected_warning):
-    w = [0.1, 0.3, 0.6]
-    data = [0, 5, 10]
-    initval = 0.7  # Value that will be negative on the unconstrained space
-
-    with pm.Model() as m_ref:
-        sigma = pm.Mixture(
-            "sigma",
-            w=w,
-            comp_dists=pm.HalfNormal.dist([1, 2, 3]),
-            initval=initval,
-            default_transform=transform,
-        )
-        y = pm.Normal("y", 0, sigma, observed=data)
-
-    with Model() as m:
-        idx = pm.Categorical("idx", p=w)
-        sigma = pm.HalfNormal(
-            "sigma",
-            pt.switch(
-                pt.eq(idx, 0),
-                1,
-                pt.switch(
-                    pt.eq(idx, 1),
-                    2,
-                    3,
-                ),
-            ),
-            default_transform=transform,
-        )
-        y = pm.Normal("y", 0, sigma, observed=data)
-
-    with expected_warning:
-        marginal_m = marginalize(m, [idx])
-
-    marginal_m.set_initval(marginal_m["sigma"], initval)
-    ip = marginal_m.initial_point()
-    if transform is not None:
-        if transform is UNSET:
-            transform_name = "log"
-        else:
-            transform_name = transform.name
-        assert -np.inf < ip[f"sigma_{transform_name}__"] < 0.0
-    np.testing.assert_allclose(marginal_m.compile_logp()(ip), m_ref.compile_logp()(ip))
-
-
-def test_data_container():
-    """Test that MarginalModel can handle Data containers."""
-    with Model(coords={"obs": [0]}) as m:
-        x = pm.Data("x", 2.5)
-        idx = pm.Bernoulli("idx", p=0.7, dims="obs")
-        y = pm.Normal("y", idx * x, dims="obs")
-
-    marginal_m = marginalize(m, [idx])
-
-    logp_fn = marginal_m.compile_logp()
-
-    with pm.Model(coords={"obs": [0]}) as m_ref:
-        x = pm.Data("x", 2.5)
-        y = pm.NormalMixture("y", w=[0.3, 0.7], mu=[0, x], dims="obs")
-
-    ref_logp_fn = m_ref.compile_logp()
-
-    for i, x_val in enumerate((-1.5, 2.5, 3.5), start=1):
-        for m in (marginal_m, m_ref):
-            m.set_dim("obs", new_length=i, coord_values=tuple(range(i)))
-            pm.set_data({"x": x_val}, model=m)
-
-        ip = marginal_m.initial_point()
-        np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ip))
-
-
-def test_mutable_indexing_jax_backend():
-    pytest.importorskip("jax")
-    from pymc.sampling.jax import get_jaxified_logp
-
-    with Model() as model:
-        data = pm.Data("data", np.zeros(10))
-
-        cat_effect = pm.Normal("cat_effect", sigma=1, shape=5)
-        cat_effect_idx = pm.Data("cat_effect_idx", np.array([0, 1] * 5))
-
-        is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10)
-        pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
-    marginal_model = marginalize(model, ["is_outlier"])
-    get_jaxified_logp(marginal_model)
-
-
-class TestFullModels:
-    @pytest.fixture
-    def disaster_model(self):
-        # fmt: off
-        disaster_data = pd.Series(
-            [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
-             3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
-             2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,
-             1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
-             0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
-             3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
-             0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]
-        )
-        # fmt: on
-        years = np.arange(1851, 1962)
-
-        with Model() as disaster_model:
-            switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max())
-            early_rate = pm.Exponential("early_rate", 1.0)
-            late_rate = pm.Exponential("late_rate", 1.0)
-            rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
-            with pytest.warns(Warning):
-                disasters = pm.Poisson("disasters", rate, observed=disaster_data)
-
-        return disaster_model, years
-
-    def test_change_point_model(self, disaster_model):
-        m, years = disaster_model
-
-        ip = m.initial_point()
-        ip["late_rate_log__"] += 1.0  # Make early and endpoint ip different
-
-        ip.pop("switchpoint")
-        ref_logp_fn = m.compile_logp(
-            [m["switchpoint"], m["disasters_observed"], m["disasters_unobserved"]]
-        )
-        ref_logp = logsumexp([ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years])
-
-        marginal_m = marginalize(m, m["switchpoint"])
-
-        with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-            marginal_m_logp = marginal_m.compile_logp(
-                [marginal_m["disasters_observed"], marginal_m["disasters_unobserved"]]
-            )(ip)
-        np.testing.assert_almost_equal(marginal_m_logp, ref_logp)
-
-    @pytest.mark.slow
-    def test_change_point_model_sampling(self, disaster_model):
-        m, _ = disaster_model
-
-        rng = np.random.default_rng(211)
-
-        with m:
-            before_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(
-                sample=("draw", "chain")
-            )
-
-        marginal_m = marginalize(m, "switchpoint")
-
-        with marginal_m:
-            with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-                after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(
-                    sample=("draw", "chain")
-                )
-
-        np.testing.assert_allclose(
-            before_marg["early_rate"].mean(), after_marg["early_rate"].mean(), rtol=1e-2
-        )
-        np.testing.assert_allclose(
-            before_marg["late_rate"].mean(), after_marg["late_rate"].mean(), rtol=1e-2
-        )
-        np.testing.assert_allclose(
-            before_marg["disasters_unobserved"].mean(),
-            after_marg["disasters_unobserved"].mean(),
-            rtol=1e-2,
-        )
-
-    @pytest.mark.parametrize("univariate", (True, False))
-    def test_vector_univariate_mixture(self, univariate):
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
-
-            def dist(idx, size):
-                return pm.math.switch(
-                    pm.math.eq(idx, 0),
-                    pm.Normal.dist([-10, -10], 1),
-                    pm.Normal.dist([10, 10], 1),
-                )
-
-            pm.CustomDist("norm", idx, dist=dist)
-
-        marginal_m = marginalize(m, idx)
-        logp_fn = marginal_m.compile_logp()
-
-        if univariate:
-            with pm.Model() as ref_m:
-                pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
-        else:
-            with pm.Model() as ref_m:
-                pm.Mixture(
-                    "norm",
-                    w=[0.5, 0.5],
-                    comp_dists=[
-                        pm.MvNormal.dist([-10, -10], np.eye(2)),
-                        pm.MvNormal.dist([10, 10], np.eye(2)),
-                    ],
-                    shape=(2,),
-                )
-        ref_logp_fn = ref_m.compile_logp()
-
-        for test_value in (
-            [-10, -10],
-            [10, 10],
-            [-10, 10],
-            [-10, 10],
-        ):
-            pt = {"norm": test_value}
-            np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
-
-    def test_k_censored_clusters_model(self):
-        data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
-        nobs = data.shape[0]
-        n_clusters = 5
-
-        def build_model(build_batched: bool) -> Model:
-            coords = {
-                "cluster": range(n_clusters),
-                "ndim": ("x", "y"),
-                "obs": range(nobs),
-            }
-            with Model(coords=coords) as m:
-                if build_batched:
-                    idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"])
-                else:
-                    idx = pm.math.stack(
-                        [
-                            pm.Categorical(f"idx_{i}", p=np.ones(n_clusters) / n_clusters)
-                            for i in range(nobs)
-                        ]
-                    )
-
-                mu_x = pm.Normal(
-                    "mu_x",
-                    dims=["cluster"],
-                    transform=ordered,
-                )
-                mu_y = pm.Normal("mu_y", dims=["cluster"])
-                mu = pm.math.stack([mu_x, mu_y], axis=-1)  # (cluster, ndim)
-                mu_indexed = mu[idx, :]
-
-                sigma = pm.HalfNormal("sigma")
-
-                y = pm.Censored(
-                    "y",
-                    dist=pm.Normal.dist(mu_indexed, sigma),
-                    lower=-3,
-                    upper=3,
-                    observed=data,
-                    dims=["obs", "ndim"],
-                )
-
-            return m
-
-        m = marginalize(build_model(build_batched=True), "idx")
-        m.set_initval(m["mu_x"], np.linspace(-1, 1, n_clusters))
-
-        ref_m = marginalize(build_model(build_batched=False), [f"idx_{i}" for i in range(nobs)])
-        test_point = m.initial_point()
-        np.testing.assert_almost_equal(
-            m.compile_logp()(test_point),
-            ref_m.compile_logp()(test_point),
-        )
-
-
-def test_unmarginalize():
-    with pm.Model() as m:
-        idx = pm.Bernoulli("idx", p=0.5)
-        sub_idx = pm.Bernoulli("sub_idx", p=pt.as_tensor([0.3, 0.7])[idx])
-        x = pm.Normal("x", mu=(idx + sub_idx) - 1)
-
-    marginal_m = marginalize(m, [idx, sub_idx])
-    assert not equivalent_models(marginal_m, m)
-
-    unmarginal_m = unmarginalize(marginal_m)
-    assert equivalent_models(unmarginal_m, m)
-
-    unmarginal_idx_explicit = unmarginalize(marginal_m, ("idx", "sub_idx"))
-    assert equivalent_models(unmarginal_idx_explicit, m)
-
-    # Test partial unmarginalize
-    unmarginal_idx = unmarginalize(marginal_m, "idx")
-    assert equivalent_models(unmarginal_idx, marginalize(m, "sub_idx"))
-
-    unmarginal_sub_idx = unmarginalize(marginal_m, "sub_idx")
-    assert equivalent_models(unmarginal_sub_idx, marginalize(m, "idx"))
-
-
-class TestRecoverMarginals:
-    def test_basic(self):
-        with Model() as m:
-            sigma = pm.HalfNormal("sigma")
-            p = np.array([0.5, 0.2, 0.3])
-            k = pm.Categorical("k", p=p)
-            mu = np.array([-3.0, 0.0, 3.0])
-            mu_ = pt.as_tensor_variable(mu)
-            y = pm.Normal("y", mu=mu_[k], sigma=sigma)
-
-        marginal_m = marginalize(m, [k])
-
-        rng = np.random.default_rng(211)
-
-        with marginal_m:
-            prior = pm.sample_prior_predictive(
-                draws=20,
-                random_seed=rng,
-                return_inferencedata=False,
-            )
-            idata = InferenceData(posterior=dict_to_dataset(prior))
-
-        idata = recover_marginals(marginal_m, idata, return_samples=True)
-        post = idata.posterior
-        assert "k" in post
-        assert "lp_k" in post
-        assert post.k.shape == post.y.shape
-        assert post.lp_k.shape == (*post.k.shape, len(p))
-
-        def true_logp(y, sigma):
-            y = y.repeat(len(p)).reshape(len(y), -1)
-            sigma = sigma.repeat(len(p)).reshape(len(sigma), -1)
-            return log_softmax(
-                np.log(p)
-                + norm.logpdf(y, loc=mu, scale=sigma)
-                + halfnorm.logpdf(sigma)
-                + np.log(sigma),
-                axis=1,
-            )
-
-        np.testing.assert_almost_equal(
-            true_logp(post.y.values.flatten(), post.sigma.values.flatten()),
-            post.lp_k[0].values,
-        )
-        np.testing.assert_almost_equal(logsumexp(post.lp_k, axis=-1), 0)
-
-    def test_coords(self):
-        """Test if coords can be recovered with marginalized value had it originally"""
-        with Model(coords={"year": [1990, 1991, 1992]}) as m:
-            sigma = pm.HalfNormal("sigma")
-            idx = pm.Bernoulli("idx", p=0.75, dims="year")
-            x = pm.Normal("x", mu=idx, sigma=sigma, dims="year")
-
-        marginal_m = marginalize(m, [idx])
-        rng = np.random.default_rng(211)
-
-        with marginal_m:
-            prior = pm.sample_prior_predictive(
-                draws=20,
-                random_seed=rng,
-                return_inferencedata=False,
-            )
-            idata = InferenceData(
-                posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
-            )
-
-        idata = recover_marginals(marginal_m, idata, return_samples=True)
-        post = idata.posterior
-        assert post.idx.dims == ("chain", "draw", "year")
-        assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim")
-
-    def test_batched(self):
-        """Test that marginalization works for batched random variables"""
-        with Model() as m:
-            sigma = pm.HalfNormal("sigma")
-            idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
-            y = pm.Normal("y", mu=idx.T, sigma=sigma, shape=(2, 3))
-
-        marginal_m = marginalize(m, [idx])
-
-        rng = np.random.default_rng(211)
-
-        with marginal_m:
-            prior = pm.sample_prior_predictive(
-                draws=20,
-                random_seed=rng,
-                return_inferencedata=False,
-            )
-            idata = InferenceData(
-                posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
-            )
-
-        idata = recover_marginals(marginal_m, idata, return_samples=True)
-        post = idata.posterior
-        assert post["y"].shape == (1, 20, 2, 3)
-        assert post["idx"].shape == (1, 20, 3, 2)
-        assert post["lp_idx"].shape == (1, 20, 3, 2, 2)
-
-    def test_nested(self):
-        """Test that marginalization works when there are nested marginalized RVs"""
-
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.75)
-            sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95))
-            sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0)
-
-        marginal_m = marginalize(m, [idx, sub_idx])
-
-        rng = np.random.default_rng(211)
-
-        with marginal_m:
-            prior = pm.sample_prior_predictive(
-                draws=20,
-                random_seed=rng,
-                return_inferencedata=False,
-            )
-            idata = InferenceData(posterior=dict_to_dataset(prior))
-
-        idata = recover_marginals(marginal_m, idata, return_samples=True)
-        post = idata.posterior
-        assert "idx" in post
-        assert "lp_idx" in post
-        assert post.idx.shape == post.y.shape
-        assert post.lp_idx.shape == (*post.idx.shape, 2)
-        assert "sub_idx" in post
-        assert "lp_sub_idx" in post
-        assert post.sub_idx.shape == post.y.shape
-        assert post.lp_sub_idx.shape == (*post.sub_idx.shape, 2)
-
-        def true_idx_logp(y):
-            idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1))
-            idx_1 = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
-            return log_softmax(np.stack([idx_0, idx_1]).T, axis=1)
-
-        np.testing.assert_almost_equal(
-            true_idx_logp(post.y.values.flatten()),
-            post.lp_idx[0].values,
-        )
-
-        def true_sub_idx_logp(y):
-            sub_idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1))
-            sub_idx_1 = np.log(0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
-            return log_softmax(np.stack([sub_idx_0, sub_idx_1]).T, axis=1)
-
-        np.testing.assert_almost_equal(
-            true_sub_idx_logp(post.y.values.flatten()),
-            post.lp_sub_idx[0].values,
-        )
-        np.testing.assert_almost_equal(logsumexp(post.lp_idx, axis=-1), 0)
-        np.testing.assert_almost_equal(logsumexp(post.lp_sub_idx, axis=-1), 0)