pymc-extras 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- pymc_extras/__init__.py +5 -1
- pymc_extras/distributions/timeseries.py +1 -1
- pymc_extras/inference/fit.py +0 -4
- pymc_extras/inference/pathfinder/__init__.py +3 -0
- pymc_extras/inference/pathfinder/importance_sampling.py +139 -0
- pymc_extras/inference/pathfinder/lbfgs.py +190 -0
- pymc_extras/inference/pathfinder/pathfinder.py +1746 -0
- pymc_extras/model/marginal/distributions.py +100 -3
- pymc_extras/model/marginal/graph_analysis.py +8 -9
- pymc_extras/model/marginal/marginal_model.py +437 -424
- pymc_extras/model/model_api.py +18 -2
- pymc_extras/statespace/core/statespace.py +79 -36
- pymc_extras/statespace/models/structural.py +21 -6
- pymc_extras/utils/model_equivalence.py +66 -0
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.2.dist-info}/METADATA +15 -5
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.2.dist-info}/RECORD +28 -24
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.2.dist-info}/WHEEL +1 -1
- tests/model/marginal/test_distributions.py +12 -11
- tests/model/marginal/test_marginal_model.py +301 -201
- tests/model/test_model_api.py +9 -0
- tests/statespace/test_statespace.py +54 -0
- tests/statespace/test_structural.py +10 -3
- tests/test_pathfinder.py +135 -7
- tests/test_pivoted_cholesky.py +1 -1
- tests/utils.py +0 -31
- pymc_extras/inference/pathfinder.py +0 -134
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.2.dist-info}/LICENSE +0 -0
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.2.dist-info}/top_level.txt +0 -0
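The headline change for the marginal-model module in this release is visible in the test diff below: the `MarginalModel` class API (`m.marginalize([...])`, which mutated the model in place) is replaced by functional `marginalize`, `unmarginalize`, and `recover_marginals` helpers that take and return ordinary `pm.Model` objects. A minimal usage sketch, assuming pymc-extras 0.2.2 and the imports shown in the test diff (the toy model here is illustrative, not taken from the package):

    import numpy as np
    import pymc as pm

    from pymc_extras.model.marginal.marginal_model import marginalize, unmarginalize

    # An ordinary PyMC model with a discrete variable to sum out
    with pm.Model() as m:
        sigma = pm.HalfNormal("sigma")
        idx = pm.Bernoulli("idx", p=0.75)
        y = pm.Normal("y", mu=idx * 2 - 1, sigma=sigma, observed=np.zeros(5))

    # Returns a new model with "idx" summed out, instead of mutating
    # a MarginalModel in place as in 0.2.0
    marginal_m = marginalize(m, ["idx"])
    logp = marginal_m.compile_logp()(marginal_m.initial_point())

    # New in this release: reverse the transformation
    restored_m = unmarginalize(marginal_m)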
tests/model/marginal/test_marginal_model.py

@@ -9,25 +9,28 @@ import pytensor.tensor as pt
 import pytest

 from arviz import InferenceData, dict_to_dataset
+from pymc import Model, draw
 from pymc.distributions import transforms
 from pymc.distributions.transforms import ordered
-from pymc.
-from pymc.pytensorf import inputvars
+from pymc.initial_point import make_initial_point_expression
+from pymc.pytensorf import constant_fold, inputvars
 from pymc.util import UNSET
 from scipy.special import log_softmax, logsumexp
 from scipy.stats import halfnorm, norm

+from pymc_extras.model.marginal.distributions import MarginalRV
 from pymc_extras.model.marginal.marginal_model import (
-    MarginalModel,
     marginalize,
+    recover_marginals,
+    unmarginalize,
 )
-from
+from pymc_extras.utils.model_equivalence import equivalent_models


 def test_basic_marginalized_rv():
     data = [2] * 5

-    with
+    with Model() as m:
         sigma = pm.HalfNormal("sigma")
         idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6])
         mu = pt.switch(
@@ -42,79 +45,105 @@ def test_basic_marginalized_rv():
         y = pm.Normal("y", mu=mu, sigma=sigma)
         z = pm.Normal("z", y, observed=data)

-
-    assert
-    assert [rv.name for rv in
+    marginal_m = marginalize(m, [idx])
+    assert isinstance(marginal_m["y"].owner.op, MarginalRV)
+    assert ["idx"] not in [rv.name for rv in marginal_m.free_RVs]
+
+    # Test forward draws
+    y_draws, z_draws = draw(
+        [marginal_m["y"], marginal_m["z"]],
+        # Make sigma very small to make draws deterministic
+        givens={marginal_m["sigma"]: 0.001},
+        draws=1000,
+        random_seed=54,
+    )
+    assert sorted(np.unique(y_draws.round()) == [-1.0, 0.0, 1.0])
+    assert z_draws[y_draws < 0].mean() < z_draws[y_draws > 0].mean()
+
+    # Test initial_point
+    ips = make_initial_point_expression(
+        # Use basic_RVs to include the observed RV
+        free_rvs=marginal_m.basic_RVs,
+        rvs_to_transforms=marginal_m.rvs_to_transforms,
+        initval_strategies={},
+    )
+    # After simplification, we should have only constants in the graph (expect alloc which isn't constant folded):
+    ip_sigma, ip_y, ip_z = constant_fold(ips)
+    np.testing.assert_allclose(ip_sigma, 1.0)
+    np.testing.assert_allclose(ip_y, 1.0)
+    np.testing.assert_allclose(ip_z, np.full((5,), 1.0))
+
+    marginal_ip = marginal_m.initial_point()
+    expected_ip = m.initial_point()
+    expected_ip.pop("idx")
+    assert marginal_ip == expected_ip

     # Test logp
-    with pm.Model() as
+    with pm.Model() as ref_m:
         sigma = pm.HalfNormal("sigma")
         y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma)
         z = pm.Normal("z", y, observed=data)

-
-
-
-
-
-
-
-
-        ref_logp,
-    )
-    np.testing.assert_almost_equal(
-        m.compile_dlogp([m["y"]])(test_point),
-        ref_dlogp,
-    )
+    np.testing.assert_almost_equal(
+        marginal_m.compile_logp()(marginal_ip),
+        ref_m.compile_logp()(marginal_ip),
+    )
+    np.testing.assert_almost_equal(
+        marginal_m.compile_dlogp([marginal_m["y"]])(marginal_ip),
+        ref_m.compile_dlogp([ref_m["y"]])(marginal_ip),
+    )


 def test_one_to_one_marginalized_rvs():
     """Test case with multiple, independent marginalized RVs."""
-    with
+    with Model() as m:
         sigma = pm.HalfNormal("sigma")
         idx1 = pm.Bernoulli("idx1", p=0.75)
         x = pm.Normal("x", mu=idx1, sigma=sigma)
         idx2 = pm.Bernoulli("idx2", p=0.75, shape=(5,))
         y = pm.Normal("y", mu=(idx2 * 2 - 1), sigma=sigma, shape=(5,))

-
-
-
-
+    marginal_m = marginalize(m, [idx1, idx2])
+    assert isinstance(marginal_m["x"].owner.op, MarginalRV)
+    assert isinstance(marginal_m["y"].owner.op, MarginalRV)
+    assert marginal_m["x"].owner is not marginal_m["y"].owner

-    with pm.Model() as
+    with pm.Model() as ref_m:
         sigma = pm.HalfNormal("sigma")
         x = pm.NormalMixture("x", w=[0.25, 0.75], mu=[0, 1], sigma=sigma)
         y = pm.NormalMixture("y", w=[0.25, 0.75], mu=[-1, 1], sigma=sigma, shape=(5,))

     # Test logp
-    test_point =
-    x_logp, y_logp =
-
+    test_point = ref_m.initial_point()
+    x_logp, y_logp = marginal_m.compile_logp(vars=[marginal_m["x"], marginal_m["y"]], sum=False)(
+        test_point
+    )
+    x_ref_log, y_ref_logp = ref_m.compile_logp(vars=[ref_m["x"], ref_m["y"]], sum=False)(test_point)
     np.testing.assert_array_almost_equal(x_logp, x_ref_log.sum())
     np.testing.assert_array_almost_equal(y_logp, y_ref_logp)


 def test_one_to_many_marginalized_rvs():
     """Test that marginalization works when there is more than one dependent RV"""
-    with
+    with Model() as m:
         sigma = pm.HalfNormal("sigma")
         idx = pm.Bernoulli("idx", p=0.75)
         x = pm.Normal("x", mu=idx, sigma=sigma)
         y = pm.Normal("y", mu=(idx * 2 - 1), sigma=sigma, shape=(5,))

-
-
-    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-        m.marginalize([idx])
+    marginal_m = marginalize(m, [idx])

-
-
-
+    marginal_x = marginal_m["x"]
+    marginal_y = marginal_m["y"]
+    assert isinstance(marginal_x.owner.op, MarginalRV)
+    assert isinstance(marginal_y.owner.op, MarginalRV)
+    assert marginal_x.owner is marginal_y.owner

-
+    ref_logp_x_y_fn = m.compile_logp([idx, x, y])
+    tp = marginal_m.initial_point()
     ref_logp_x_y = logsumexp([ref_logp_x_y_fn({**tp, **{"idx": idx}}) for idx in (0, 1)])
-
+    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+        logp_x_y = marginal_m.compile_logp([marginal_x, marginal_y])(tp)
     np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y)


@@ -122,7 +151,7 @@ def test_one_to_many_unaligned_marginalized_rvs():
     """Test that marginalization works when there is more than one dependent RV with batch dimensions that are not aligned"""

     def build_model(build_batched: bool):
-        with
+        with Model() as m:
             if build_batched:
                 idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2))
             else:
@@ -134,44 +163,41 @@ def test_one_to_many_unaligned_marginalized_rvs():

         return m

-
-    ref_m = build_model(build_batched=False)
+    marginal_m = marginalize(build_model(build_batched=True), ["idx"])
+    ref_m = marginalize(build_model(build_batched=False), [f"idx_{i}" for i in range(6)])

-
-    m.marginalize(["idx"])
-    ref_m.marginalize([f"idx_{i}" for i in range(6)])
+    test_point = marginal_m.initial_point()

-
-
-
-
-
+    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+        np.testing.assert_allclose(
+            marginal_m.compile_logp()(test_point),
+            ref_m.compile_logp()(test_point),
+        )


 def test_many_to_one_marginalized_rvs():
     """Test when random variables depend on multiple marginalized variables"""
-    with
+    with Model() as m:
         x = pm.Bernoulli("x", 0.1)
         y = pm.Bernoulli("y", 0.3)
         z = pm.DiracDelta("z", c=x + y)

-
-    logp = m.compile_logp()
+    logp_fn = marginalize(m, [x, y]).compile_logp()

-    np.testing.assert_allclose(np.exp(
-    np.testing.assert_allclose(np.exp(
-    np.testing.assert_allclose(np.exp(
+    np.testing.assert_allclose(np.exp(logp_fn({"z": 0})), 0.9 * 0.7)
+    np.testing.assert_allclose(np.exp(logp_fn({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7)
+    np.testing.assert_allclose(np.exp(logp_fn({"z": 2})), 0.1 * 0.3)


 @pytest.mark.parametrize("batched", (False, "left", "right"))
 def test_nested_marginalized_rvs(batched):
     """Test that marginalization works when there are nested marginalized RVs"""

-    def build_model(build_batched: bool) ->
+    def build_model(build_batched: bool) -> Model:
         idx_shape = (3,) if build_batched else ()
         sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5)

-        with
+        with Model() as m:
             sigma = pm.HalfNormal("sigma")

             idx = pm.Bernoulli("idx", p=0.75, shape=idx_shape)
@@ -186,10 +212,33 @@ def test_nested_marginalized_rvs(batched):

         return m

-
-
-
-
+    marginal_m = marginalize(build_model(build_batched=batched), ["idx", "sub_idx"])
+    assert all(rv.name not in ("idx", "sub_idx") for rv in marginal_m.free_RVs)
+
+    # Test forward draws and initial_point, shouldn't depend on batching, so we only test one case
+    if not batched:
+        # Test forward draws
+        dep_draws, sub_dep_draws = draw(
+            [marginal_m["dep"], marginal_m["sub_dep"]],
+            # Make sigma very small to make draws deterministic
+            givens={marginal_m["sigma"]: 0.001},
+            draws=1000,
+            random_seed=214,
+        )
+        assert sorted(np.unique(dep_draws.round()) == [-1000.0, 1000.0])
+        assert sorted(np.unique(sub_dep_draws.round()) == [-1000.0, -900.0, 1000.0, 1100.0])
+
+        # Test initial_point
+        ips = make_initial_point_expression(
+            free_rvs=marginal_m.free_RVs,
+            rvs_to_transforms=marginal_m.rvs_to_transforms,
+            initval_strategies={},
+        )
+        # After simplification, we should have only constants in the graph
+        ip_sigma, ip_dep, ip_sub_dep = constant_fold(ips)
+        np.testing.assert_allclose(ip_sigma, 1.0)
+        np.testing.assert_allclose(ip_dep, 1000.0)
+        np.testing.assert_allclose(ip_sub_dep, np.full((5,), 1100.0))

     # Test logp
     ref_m = build_model(build_batched=False)
@@ -210,14 +259,70 @@ def test_nested_marginalized_rvs(batched):
     if batched:
         ref_logp *= 3

-    test_point =
+    test_point = marginal_m.initial_point()
     test_point["dep"] = np.full_like(test_point["dep"], 1000)
     test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
-
+
+    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+        logp = marginal_m.compile_logp(vars=[marginal_m["dep"], marginal_m["sub_dep"]])(test_point)

     np.testing.assert_almost_equal(logp, ref_logp)


+def test_interdependent_rvs():
+    """Test Marginalization when dependent RVs are interdependent."""
+    with Model() as m:
+        idx = pm.Bernoulli("idx", p=0.75)
+        x = pm.Normal("x", mu=idx * 2, sigma=1e-3)
+        # Y depends on both x and idx
+        y = pm.Normal("y", mu=x * idx * 2, sigma=1e-3)
+
+    marginal_m = marginalize(m, "idx")
+
+    marginal_x = marginal_m["x"]
+    marginal_y = marginal_m["y"]
+    assert isinstance(marginal_x.owner.op, MarginalRV)
+    assert isinstance(marginal_y.owner.op, MarginalRV)
+    assert marginal_x.owner is marginal_y.owner
+
+    # Test forward draws
+    x_draws, y_draws = draw([marginal_x, marginal_y], draws=1000, random_seed=54)
+    assert sorted(np.unique(x_draws.round())) == [0, 2]
+    assert sorted(np.unique(y_draws.round())) == [0, 4]
+    assert np.unique(y_draws[x_draws < 1].round()) == [0]
+    assert np.unique(y_draws[x_draws > 1].round()) == [4]
+
+    # Test initial_point
+    ips = make_initial_point_expression(
+        free_rvs=marginal_m.free_RVs,
+        rvs_to_transforms={},
+        initval_strategies={},
+    )
+    # After simplification, we should have only constants in the graph
+    ip_x, ip_y = constant_fold(ips)
+    np.testing.assert_allclose(ip_x, 2.0)
+    np.testing.assert_allclose(ip_y, 4.0)
+
+    # Test custom initval strategy
+    ips = make_initial_point_expression(
+        # Test that order does not matter
+        free_rvs=marginal_m.free_RVs[::-1],
+        rvs_to_transforms={},
+        initval_strategies={marginal_x: pt.constant(5.0)},
+    )
+    ip_y, ip_x = constant_fold(ips)
+    np.testing.assert_allclose(ip_x, 5.0)
+    np.testing.assert_allclose(ip_y, 10.0)
+
+    # Test logp
+    test_point = marginal_m.initial_point()
+    ref_logp_fn = m.compile_logp([m["idx"], m["x"], m["y"]])
+    ref_logp = logsumexp([ref_logp_fn({**test_point, **{"idx": idx}}) for idx in (0, 1)])
+    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+        logp = marginal_m.compile_logp([marginal_m["x"], marginal_m["y"]])(test_point)
+    np.testing.assert_almost_equal(logp, ref_logp)
+
+
 @pytest.mark.parametrize("advanced_indexing", (False, True))
 def test_marginalized_index_as_key(advanced_indexing):
     """Test we can marginalize graphs where indexing is used as a mapping."""
@@ -232,13 +337,13 @@ def test_marginalized_index_as_key(advanced_indexing):
         y_val = -1
         shape = ()

-    with
+    with Model() as m:
         x = pm.Categorical("x", p=w, shape=shape)
         y = pm.Normal("y", mu[x].T, sigma=1, observed=y_val)

-
+    marginal_m = marginalize(m, x)

-    marginal_logp =
+    marginal_logp = marginal_m.compile_logp(sum=False)({})[0]
     ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu.T, sigma=1, shape=shape), y_val).eval()

     np.testing.assert_allclose(marginal_logp, ref_logp)
@@ -247,8 +352,8 @@ def test_marginalized_index_as_key(advanced_indexing):
 def test_marginalized_index_as_value_and_key():
     """Test we can marginalize graphs were marginalized_rv is indexed."""

-    def build_model(build_batched: bool) ->
-        with
+    def build_model(build_batched: bool) -> Model:
+        with Model() as m:
             if build_batched:
                 latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,))
             else:
@@ -270,16 +375,16 @@ def test_marginalized_index_as_value_and_key():
     m = build_model(build_batched=True)
     ref_m = build_model(build_batched=False)

-    m
-    ref_m
+    m = marginalize(m, ["latent_state"])
+    ref_m = marginalize(ref_m, [f"latent_state_{i}" for i in range(4)])
     test_point = {"picked_intensity": 1}
     np.testing.assert_allclose(
         m.compile_logp()(test_point),
         ref_m.compile_logp()(test_point),
     )

-    m
-    ref_m
+    m = marginalize(m, ["picked_intensity"])
+    ref_m = marginalize(ref_m, ["picked_intensity"])
     test_point = {}
     np.testing.assert_allclose(
         m.compile_logp()(test_point),
@@ -291,99 +396,99 @@ class TestNotSupportedMixedDims:
     """Test lack of support for models where batch dims of marginalized variables are mixed."""

     def test_mixed_dims_via_transposed_dot(self):
-        with
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=idx @ idx.T)
-
-
+
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)

     def test_mixed_dims_via_indexing(self):
         mean = pt.as_tensor([[0.1, 0.9], [0.6, 0.4]])

-        with
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=mean[idx, :] + mean[:, idx])
-
-
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)

-        with
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=mean[idx, None] + mean[None, idx])
-
-
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)

-        with
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             mu = pt.specify_broadcastable(mean[:, None][idx], 1) + pt.specify_broadcastable(
                 mean[None, :][:, idx], 0
             )
             y = pm.Normal("y", mu=mu)
-
-
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)

-        with
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=idx[0] + idx[1])
-
-
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)

     def test_mixed_dims_via_vector_indexing(self):
-        with
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=idx[[0, 1, 0, 0]])
-
-
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)

-        with
+        with Model() as m:
             idx = pm.Categorical("key", p=[0.1, 0.3, 0.6], shape=(2, 2))
             y = pm.Normal("y", pt.as_tensor([[0, 1], [2, 3]])[idx.astype(bool)])
-
-
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)

     def test_mixed_dims_via_support_dimension(self):
-        with
+        with Model() as m:
             x = pm.Bernoulli("x", p=0.7, shape=3)
             y = pm.Dirichlet("y", a=x * 10 + 1)
-
-
+        with pytest.raises(NotImplementedError):
+            marginalize(m, x)

     def test_mixed_dims_via_nested_marginalization(self):
-        with
+        with Model() as m:
             x = pm.Bernoulli("x", p=0.7, shape=(3,))
             y = pm.Bernoulli("y", p=0.7, shape=(2,))
             z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2))

-
-
+        with pytest.raises(NotImplementedError):
+            marginalize(m, [x, y])


 def test_marginalized_deterministic_and_potential():
     rng = np.random.default_rng(299)

-    with
+    with Model() as m:
         x = pm.Bernoulli("x", p=0.7)
         y = pm.Normal("y", x)
         z = pm.Normal("z", x)
         det = pm.Deterministic("det", y + z)
         pot = pm.Potential("pot", y + z + 1)

-
-    m.marginalize([x])
+    marginal_m = marginalize(m, [x])

     y_draw, z_draw, det_draw, pot_draw = pm.draw([y, z, det, pot], draws=5, random_seed=rng)
     np.testing.assert_almost_equal(y_draw + z_draw, det_draw)
     np.testing.assert_almost_equal(det_draw, pot_draw - 1)

-    y_value =
-    z_value =
-    det_value, pot_value =
+    y_value = marginal_m.rvs_to_values[marginal_m["y"]]
+    z_value = marginal_m.rvs_to_values[marginal_m["z"]]
+    det_value, pot_value = marginal_m.replace_rvs_by_values([marginal_m["det"], marginal_m["pot"]])
     assert set(inputvars([det_value, pot_value])) == {y_value, z_value}
     assert det_value.eval({y_value: 2, z_value: 5}) == 7
     assert pot_value.eval({y_value: 2, z_value: 5}) == 8


 def test_not_supported_marginalized_deterministic_and_potential():
-    with
+    with Model() as m:
         x = pm.Bernoulli("x", p=0.7)
         y = pm.Normal("y", x)
         det = pm.Deterministic("det", x + y)
@@ -391,9 +496,9 @@ def test_not_supported_marginalized_deterministic_and_potential():
     with pytest.raises(
         NotImplementedError, match="Cannot marginalize x due to dependent Deterministic det"
     ):
-
+        marginalize(m, [x])

-    with
+    with Model() as m:
         x = pm.Bernoulli("x", p=0.7)
         y = pm.Normal("y", x)
         pot = pm.Potential("pot", x + y)
@@ -401,7 +506,7 @@ def test_not_supported_marginalized_deterministic_and_potential():
     with pytest.raises(
         NotImplementedError, match="Cannot marginalize x due to dependent Potential pot"
     ):
-
+        marginalize(m, [x])


 @pytest.mark.parametrize(
@@ -410,15 +515,15 @@ def test_not_supported_marginalized_deterministic_and_potential():
         (None, does_not_warn()),
         (UNSET, does_not_warn()),
         (transforms.log, does_not_warn()),
-        (transforms.Chain([transforms.
+        (transforms.Chain([transforms.logodds, transforms.log]), does_not_warn()),
         (
-            transforms.Interval(0,
+            transforms.Interval(0, 2),
             pytest.warns(
                 UserWarning, match="which depends on the marginalized idx may no longer work"
             ),
         ),
         (
-            transforms.Chain([transforms.log, transforms.Interval(
+            transforms.Chain([transforms.log, transforms.Interval(-1, 1)]),
             pytest.warns(
                 UserWarning, match="which depends on the marginalized idx may no longer work"
             ),
@@ -428,7 +533,7 @@ def test_not_supported_marginalized_deterministic_and_potential():
 def test_marginalized_transforms(transform, expected_warning):
     w = [0.1, 0.3, 0.6]
     data = [0, 5, 10]
-    initval = 0.
+    initval = 0.7  # Value that will be negative on the unconstrained space

     with pm.Model() as m_ref:
         sigma = pm.Mixture(
@@ -440,7 +545,7 @@ def test_marginalized_transforms(transform, expected_warning):
         )
         y = pm.Normal("y", 0, sigma, observed=data)

-    with
+    with Model() as m:
         idx = pm.Categorical("idx", p=w)
         sigma = pm.HalfNormal(
             "sigma",
@@ -453,32 +558,32 @@ def test_marginalized_transforms(transform, expected_warning):
                     3,
                 ),
             ),
-            initval=initval,
             default_transform=transform,
         )
         y = pm.Normal("y", 0, sigma, observed=data)

     with expected_warning:
-
+        marginal_m = marginalize(m, [idx])

-
+    marginal_m.set_initval(marginal_m["sigma"], initval)
+    ip = marginal_m.initial_point()
     if transform is not None:
         if transform is UNSET:
             transform_name = "log"
         else:
             transform_name = transform.name
-        assert f"sigma_{transform_name}__"
-    np.testing.assert_allclose(
+        assert -np.inf < ip[f"sigma_{transform_name}__"] < 0.0
+    np.testing.assert_allclose(marginal_m.compile_logp()(ip), m_ref.compile_logp()(ip))


 def test_data_container():
     """Test that MarginalModel can handle Data containers."""
-    with
+    with Model(coords={"obs": [0]}) as m:
         x = pm.Data("x", 2.5)
         idx = pm.Bernoulli("idx", p=0.7, dims="obs")
         y = pm.Normal("y", idx * x, dims="obs")

-    marginal_m
+    marginal_m = marginalize(m, [idx])

     logp_fn = marginal_m.compile_logp()

@@ -501,7 +606,7 @@ def test_mutable_indexing_jax_backend():
     pytest.importorskip("jax")
     from pymc.sampling.jax import get_jaxified_logp

-    with
+    with Model() as model:
         data = pm.Data("data", np.zeros(10))

         cat_effect = pm.Normal("cat_effect", sigma=1, shape=5)
@@ -509,38 +614,8 @@ def test_mutable_indexing_jax_backend():

         is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10)
         pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
-
-    get_jaxified_logp(
-
-
-def test_marginal_model_func():
-    def create_model(model_class):
-        with model_class(coords={"trial": range(10)}) as m:
-            idx = pm.Bernoulli("idx", p=0.5, dims="trial")
-            mu = pt.where(idx, 1, -1)
-            sigma = pm.HalfNormal("sigma")
-            y = pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1] * 10)
-        return m
-
-    marginal_m = marginalize(create_model(pm.Model), ["idx"])
-    assert isinstance(marginal_m, MarginalModel)
-
-    reference_m = create_model(MarginalModel)
-    reference_m.marginalize(["idx"])
-
-    # Check forward graph representation is the same
-    marginal_fgraph, _ = fgraph_from_model(marginal_m)
-    reference_fgraph, _ = fgraph_from_model(reference_m)
-    assert equal_computations_up_to_root(marginal_fgraph.outputs, reference_fgraph.outputs)
-
-    # Check logp graph is the same
-    # This fails because OpFromGraphs comparison is broken
-    # assert equal_computations_up_to_root([marginal_m.logp()], [reference_m.logp()])
-    ip = marginal_m.initial_point()
-    np.testing.assert_allclose(
-        marginal_m.compile_logp()(ip),
-        reference_m.compile_logp()(ip),
-    )
+    marginal_model = marginalize(model, ["is_outlier"])
+    get_jaxified_logp(marginal_model)


 class TestFullModels:
@@ -559,10 +634,10 @@ class TestFullModels:
         # fmt: on
         years = np.arange(1851, 1962)

-        with
+        with Model() as disaster_model:
             switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max())
-            early_rate = pm.Exponential("early_rate", 1.0
-            late_rate = pm.Exponential("late_rate", 1.0
+            early_rate = pm.Exponential("early_rate", 1.0)
+            late_rate = pm.Exponential("late_rate", 1.0)
             rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
             with pytest.warns(Warning):
                 disasters = pm.Poisson("disasters", rate, observed=disaster_data)
@@ -573,17 +648,21 @@ class TestFullModels:
         m, years = disaster_model

         ip = m.initial_point()
+        ip["late_rate_log__"] += 1.0  # Make early and endpoint ip different
+
         ip.pop("switchpoint")
         ref_logp_fn = m.compile_logp(
             [m["switchpoint"], m["disasters_observed"], m["disasters_unobserved"]]
         )
         ref_logp = logsumexp([ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years])

-
-        m.marginalize(m["switchpoint"])
+        marginal_m = marginalize(m, m["switchpoint"])

-
-
+        with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+            marginal_m_logp = marginal_m.compile_logp(
+                [marginal_m["disasters_observed"], marginal_m["disasters_unobserved"]]
+            )(ip)
+        np.testing.assert_almost_equal(marginal_m_logp, ref_logp)

     @pytest.mark.slow
     def test_change_point_model_sampling(self, disaster_model):
@@ -596,13 +675,13 @@ class TestFullModels:
                 sample=("draw", "chain")
             )

-
-        m.marginalize([m["switchpoint"]])
+        marginal_m = marginalize(m, "switchpoint")

-        with
-
-            sample=
-
+        with marginal_m:
+            with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+                after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(
+                    sample=("draw", "chain")
+                )

         np.testing.assert_allclose(
             before_marg["early_rate"].mean(), after_marg["early_rate"].mean(), rtol=1e-2
@@ -618,7 +697,7 @@ class TestFullModels:

     @pytest.mark.parametrize("univariate", (True, False))
     def test_vector_univariate_mixture(self, univariate):
-        with
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())

             def dist(idx, size):
@@ -630,8 +709,8 @@ class TestFullModels:

             pm.CustomDist("norm", idx, dist=dist)

-
-        logp_fn =
+        marginal_m = marginalize(m, idx)
+        logp_fn = marginal_m.compile_logp()

         if univariate:
             with pm.Model() as ref_m:
@@ -659,16 +738,17 @@ class TestFullModels:
        np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))

     def test_k_censored_clusters_model(self):
-
-
-
-
+        data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
+        nobs = data.shape[0]
+        n_clusters = 5
+
+        def build_model(build_batched: bool) -> Model:
            coords = {
                 "cluster": range(n_clusters),
                 "ndim": ("x", "y"),
                 "obs": range(nobs),
             }
-            with
+            with Model(coords=coords) as m:
                 if build_batched:
                     idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"])
                 else:
@@ -683,7 +763,6 @@ class TestFullModels:
                     "mu_x",
                     dims=["cluster"],
                     transform=ordered,
-                    initval=np.linspace(-1, 1, n_clusters),
                 )
                 mu_y = pm.Normal("mu_y", dims=["cluster"])
                 mu = pm.math.stack([mu_x, mu_y], axis=-1)  # (cluster, ndim)
@@ -702,12 +781,10 @@ class TestFullModels:

             return m

-        m = build_model(build_batched=True)
-
-
-        m.marginalize([m["idx"]])
-        ref_m.marginalize([n for n in ref_m.named_vars if n.startswith("idx_")])
+        m = marginalize(build_model(build_batched=True), "idx")
+        m.set_initval(m["mu_x"], np.linspace(-1, 1, n_clusters))

+        ref_m = marginalize(build_model(build_batched=False), [f"idx_{i}" for i in range(nobs)])
         test_point = m.initial_point()
         np.testing.assert_almost_equal(
             m.compile_logp()(test_point),
@@ -715,9 +792,32 @@ class TestFullModels:
         )


+def test_unmarginalize():
+    with pm.Model() as m:
+        idx = pm.Bernoulli("idx", p=0.5)
+        sub_idx = pm.Bernoulli("sub_idx", p=pt.as_tensor([0.3, 0.7])[idx])
+        x = pm.Normal("x", mu=(idx + sub_idx) - 1)
+
+    marginal_m = marginalize(m, [idx, sub_idx])
+    assert not equivalent_models(marginal_m, m)
+
+    unmarginal_m = unmarginalize(marginal_m)
+    assert equivalent_models(unmarginal_m, m)
+
+    unmarginal_idx_explicit = unmarginalize(marginal_m, ("idx", "sub_idx"))
+    assert equivalent_models(unmarginal_idx_explicit, m)
+
+    # Test partial unmarginalize
+    unmarginal_idx = unmarginalize(marginal_m, "idx")
+    assert equivalent_models(unmarginal_idx, marginalize(m, "sub_idx"))
+
+    unmarginal_sub_idx = unmarginalize(marginal_m, "sub_idx")
+    assert equivalent_models(unmarginal_sub_idx, marginalize(m, "idx"))
+
+
 class TestRecoverMarginals:
     def test_basic(self):
-        with
+        with Model() as m:
             sigma = pm.HalfNormal("sigma")
             p = np.array([0.5, 0.2, 0.3])
             k = pm.Categorical("k", p=p)
@@ -725,11 +825,11 @@ class TestRecoverMarginals:
             mu_ = pt.as_tensor_variable(mu)
             y = pm.Normal("y", mu=mu_[k], sigma=sigma)

-
+        marginal_m = marginalize(m, [k])

         rng = np.random.default_rng(211)

-        with
+        with marginal_m:
             prior = pm.sample_prior_predictive(
                 draws=20,
                 random_seed=rng,
@@ -737,7 +837,7 @@ class TestRecoverMarginals:
             )
         idata = InferenceData(posterior=dict_to_dataset(prior))

-        idata =
+        idata = recover_marginals(marginal_m, idata, return_samples=True)
         post = idata.posterior
         assert "k" in post
         assert "lp_k" in post
@@ -763,15 +863,15 @@ class TestRecoverMarginals:

     def test_coords(self):
         """Test if coords can be recovered with marginalized value had it originally"""
-        with
+        with Model(coords={"year": [1990, 1991, 1992]}) as m:
             sigma = pm.HalfNormal("sigma")
             idx = pm.Bernoulli("idx", p=0.75, dims="year")
             x = pm.Normal("x", mu=idx, sigma=sigma, dims="year")

-
+        marginal_m = marginalize(m, [idx])
         rng = np.random.default_rng(211)

-        with
+        with marginal_m:
             prior = pm.sample_prior_predictive(
                 draws=20,
                 random_seed=rng,
@@ -781,23 +881,23 @@ class TestRecoverMarginals:
             posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
         )

-        idata =
+        idata = recover_marginals(marginal_m, idata, return_samples=True)
         post = idata.posterior
         assert post.idx.dims == ("chain", "draw", "year")
         assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim")

     def test_batched(self):
         """Test that marginalization works for batched random variables"""
-        with
+        with Model() as m:
             sigma = pm.HalfNormal("sigma")
             idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
             y = pm.Normal("y", mu=idx.T, sigma=sigma, shape=(2, 3))

-
+        marginal_m = marginalize(m, [idx])

         rng = np.random.default_rng(211)

-        with
+        with marginal_m:
             prior = pm.sample_prior_predictive(
                 draws=20,
                 random_seed=rng,
@@ -807,7 +907,7 @@ class TestRecoverMarginals:
             posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
         )

-        idata =
+        idata = recover_marginals(marginal_m, idata, return_samples=True)
         post = idata.posterior
         assert post["y"].shape == (1, 20, 2, 3)
         assert post["idx"].shape == (1, 20, 3, 2)
@@ -816,16 +916,16 @@ class TestRecoverMarginals:
     def test_nested(self):
         """Test that marginalization works when there are nested marginalized RVs"""

-        with
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.75)
             sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95))
             sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0)

-
+        marginal_m = marginalize(m, [idx, sub_idx])

         rng = np.random.default_rng(211)

-        with
+        with marginal_m:
             prior = pm.sample_prior_predictive(
                 draws=20,
                 random_seed=rng,
@@ -833,7 +933,7 @@ class TestRecoverMarginals:
             )
         idata = InferenceData(posterior=dict_to_dataset(prior))

-        idata =
+        idata = recover_marginals(marginal_m, idata, return_samples=True)
         post = idata.posterior
         assert "idx" in post
         assert "lp_idx" in post
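For context on the `TestRecoverMarginals` changes above, a standalone sketch of the new functional `recover_marginals` workflow (again assuming pymc-extras 0.2.2; the model and sampler settings here are illustrative, not taken from the package):

    import pymc as pm

    from pymc_extras.model.marginal.marginal_model import marginalize, recover_marginals

    with pm.Model() as m:
        idx = pm.Bernoulli("idx", p=0.7)
        y = pm.Normal("y", mu=idx, sigma=1.0)

    marginal_m = marginalize(m, ["idx"])
    with marginal_m:
        idata = pm.sample(draws=500, chains=2, random_seed=211)

    # Reconstruct draws and log-probabilities of the marginalized
    # variable from the posterior of the marginal model
    idata = recover_marginals(marginal_m, idata, return_samples=True)
    idata.posterior["idx"]     # recovered draws
    idata.posterior["lp_idx"]  # per-value log-probabilities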