pymc-extras 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (65)
  1. pymc_extras/__init__.py +5 -1
  2. pymc_extras/deserialize.py +224 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/inference/find_map.py +62 -17
  6. pymc_extras/inference/laplace.py +10 -7
  7. pymc_extras/prior.py +1356 -0
  8. pymc_extras/statespace/core/statespace.py +191 -52
  9. pymc_extras/statespace/filters/distributions.py +15 -16
  10. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  11. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  12. pymc_extras/statespace/models/ETS.py +10 -0
  13. pymc_extras/statespace/models/SARIMAX.py +26 -5
  14. pymc_extras/statespace/models/VARMAX.py +12 -2
  15. pymc_extras/statespace/models/structural.py +18 -5
  16. pymc_extras-0.2.7.dist-info/METADATA +321 -0
  17. pymc_extras-0.2.7.dist-info/RECORD +66 -0
  18. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/WHEEL +1 -2
  19. pymc_extras/utils/pivoted_cholesky.py +0 -69
  20. pymc_extras/version.py +0 -11
  21. pymc_extras/version.txt +0 -1
  22. pymc_extras-0.2.5.dist-info/METADATA +0 -112
  23. pymc_extras-0.2.5.dist-info/RECORD +0 -108
  24. pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
  25. tests/__init__.py +0 -13
  26. tests/distributions/__init__.py +0 -19
  27. tests/distributions/test_continuous.py +0 -185
  28. tests/distributions/test_discrete.py +0 -210
  29. tests/distributions/test_discrete_markov_chain.py +0 -258
  30. tests/distributions/test_multivariate.py +0 -304
  31. tests/distributions/test_transform.py +0 -77
  32. tests/model/__init__.py +0 -0
  33. tests/model/marginal/__init__.py +0 -0
  34. tests/model/marginal/test_distributions.py +0 -132
  35. tests/model/marginal/test_graph_analysis.py +0 -182
  36. tests/model/marginal/test_marginal_model.py +0 -967
  37. tests/model/test_model_api.py +0 -38
  38. tests/statespace/__init__.py +0 -0
  39. tests/statespace/test_ETS.py +0 -411
  40. tests/statespace/test_SARIMAX.py +0 -405
  41. tests/statespace/test_VARMAX.py +0 -184
  42. tests/statespace/test_coord_assignment.py +0 -181
  43. tests/statespace/test_distributions.py +0 -270
  44. tests/statespace/test_kalman_filter.py +0 -326
  45. tests/statespace/test_representation.py +0 -175
  46. tests/statespace/test_statespace.py +0 -872
  47. tests/statespace/test_statespace_JAX.py +0 -156
  48. tests/statespace/test_structural.py +0 -836
  49. tests/statespace/utilities/__init__.py +0 -0
  50. tests/statespace/utilities/shared_fixtures.py +0 -9
  51. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  52. tests/statespace/utilities/test_helpers.py +0 -310
  53. tests/test_blackjax_smc.py +0 -222
  54. tests/test_find_map.py +0 -103
  55. tests/test_histogram_approximation.py +0 -109
  56. tests/test_laplace.py +0 -281
  57. tests/test_linearmodel.py +0 -208
  58. tests/test_model_builder.py +0 -306
  59. tests/test_pathfinder.py +0 -297
  60. tests/test_pivoted_cholesky.py +0 -24
  61. tests/test_printing.py +0 -98
  62. tests/test_prior_from_trace.py +0 -172
  63. tests/test_splines.py +0 -77
  64. tests/utils.py +0 -0
  65. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/licenses/LICENSE +0 -0
--- a/tests/model/marginal/test_marginal_model.py
+++ /dev/null
@@ -1,967 +0,0 @@
-import itertools
-
-from contextlib import suppress as does_not_warn
-
-import numpy as np
-import pandas as pd
-import pymc as pm
-import pytensor.tensor as pt
-import pytest
-
-from arviz import InferenceData, dict_to_dataset
-from pymc import Model, draw
-from pymc.distributions import transforms
-from pymc.distributions.transforms import ordered
-from pymc.initial_point import make_initial_point_expression
-from pymc.pytensorf import constant_fold, inputvars
-from pymc.util import UNSET
-from scipy.special import log_softmax, logsumexp
-from scipy.stats import halfnorm, norm
-
-from pymc_extras.model.marginal.distributions import MarginalRV
-from pymc_extras.model.marginal.marginal_model import (
-    marginalize,
-    recover_marginals,
-    unmarginalize,
-)
-from pymc_extras.utils.model_equivalence import equivalent_models
-
-
-def test_basic_marginalized_rv():
-    data = [2] * 5
-
-    with Model() as m:
-        sigma = pm.HalfNormal("sigma")
-        idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6])
-        mu = pt.switch(
-            pt.eq(idx, 0),
-            -1.0,
-            pt.switch(
-                pt.eq(idx, 1),
-                0.0,
-                1.0,
-            ),
-        )
-        y = pm.Normal("y", mu=mu, sigma=sigma)
-        z = pm.Normal("z", y, observed=data)
-
-    marginal_m = marginalize(m, [idx])
-    assert isinstance(marginal_m["y"].owner.op, MarginalRV)
-    assert ["idx"] not in [rv.name for rv in marginal_m.free_RVs]
-
-    # Test forward draws
-    y_draws, z_draws = draw(
-        [marginal_m["y"], marginal_m["z"]],
-        # Make sigma very small to make draws deterministic
-        givens={marginal_m["sigma"]: 0.001},
-        draws=1000,
-        random_seed=54,
-    )
-    assert sorted(np.unique(y_draws.round()) == [-1.0, 0.0, 1.0])
-    assert z_draws[y_draws < 0].mean() < z_draws[y_draws > 0].mean()
-
-    # Test initial_point
-    ips = make_initial_point_expression(
-        # Use basic_RVs to include the observed RV
-        free_rvs=marginal_m.basic_RVs,
-        rvs_to_transforms=marginal_m.rvs_to_transforms,
-        initval_strategies={},
-    )
-    # After simplification, we should have only constants in the graph (expect alloc which isn't constant folded):
-    ip_sigma, ip_y, ip_z = constant_fold(ips)
-    np.testing.assert_allclose(ip_sigma, 1.0)
-    np.testing.assert_allclose(ip_y, 1.0)
-    np.testing.assert_allclose(ip_z, np.full((5,), 1.0))
-
-    marginal_ip = marginal_m.initial_point()
-    expected_ip = m.initial_point()
-    expected_ip.pop("idx")
-    assert marginal_ip == expected_ip
-
-    # Test logp
-    with pm.Model() as ref_m:
-        sigma = pm.HalfNormal("sigma")
-        y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma)
-        z = pm.Normal("z", y, observed=data)
-
-    np.testing.assert_almost_equal(
-        marginal_m.compile_logp()(marginal_ip),
-        ref_m.compile_logp()(marginal_ip),
-    )
-    np.testing.assert_almost_equal(
-        marginal_m.compile_dlogp([marginal_m["y"]])(marginal_ip),
-        ref_m.compile_dlogp([ref_m["y"]])(marginal_ip),
-    )
-
-
-def test_one_to_one_marginalized_rvs():
-    """Test case with multiple, independent marginalized RVs."""
-    with Model() as m:
-        sigma = pm.HalfNormal("sigma")
-        idx1 = pm.Bernoulli("idx1", p=0.75)
-        x = pm.Normal("x", mu=idx1, sigma=sigma)
-        idx2 = pm.Bernoulli("idx2", p=0.75, shape=(5,))
-        y = pm.Normal("y", mu=(idx2 * 2 - 1), sigma=sigma, shape=(5,))
-
-    marginal_m = marginalize(m, [idx1, idx2])
-    assert isinstance(marginal_m["x"].owner.op, MarginalRV)
-    assert isinstance(marginal_m["y"].owner.op, MarginalRV)
-    assert marginal_m["x"].owner is not marginal_m["y"].owner
-
-    with pm.Model() as ref_m:
-        sigma = pm.HalfNormal("sigma")
-        x = pm.NormalMixture("x", w=[0.25, 0.75], mu=[0, 1], sigma=sigma)
-        y = pm.NormalMixture("y", w=[0.25, 0.75], mu=[-1, 1], sigma=sigma, shape=(5,))
-
-    # Test logp
-    test_point = ref_m.initial_point()
-    x_logp, y_logp = marginal_m.compile_logp(vars=[marginal_m["x"], marginal_m["y"]], sum=False)(
-        test_point
-    )
-    x_ref_log, y_ref_logp = ref_m.compile_logp(vars=[ref_m["x"], ref_m["y"]], sum=False)(test_point)
-    np.testing.assert_array_almost_equal(x_logp, x_ref_log.sum())
-    np.testing.assert_array_almost_equal(y_logp, y_ref_logp)
-
-
-def test_one_to_many_marginalized_rvs():
-    """Test that marginalization works when there is more than one dependent RV"""
-    with Model() as m:
-        sigma = pm.HalfNormal("sigma")
-        idx = pm.Bernoulli("idx", p=0.75)
-        x = pm.Normal("x", mu=idx, sigma=sigma)
-        y = pm.Normal("y", mu=(idx * 2 - 1), sigma=sigma, shape=(5,))
-
-    marginal_m = marginalize(m, [idx])
-
-    marginal_x = marginal_m["x"]
-    marginal_y = marginal_m["y"]
-    assert isinstance(marginal_x.owner.op, MarginalRV)
-    assert isinstance(marginal_y.owner.op, MarginalRV)
-    assert marginal_x.owner is marginal_y.owner
-
-    ref_logp_x_y_fn = m.compile_logp([idx, x, y])
-    tp = marginal_m.initial_point()
-    ref_logp_x_y = logsumexp([ref_logp_x_y_fn({**tp, **{"idx": idx}}) for idx in (0, 1)])
-    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-        logp_x_y = marginal_m.compile_logp([marginal_x, marginal_y])(tp)
-    np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y)
-
-
-def test_one_to_many_unaligned_marginalized_rvs():
-    """Test that marginalization works when there is more than one dependent RV with batch dimensions that are not aligned"""
-
-    def build_model(build_batched: bool):
-        with Model() as m:
-            if build_batched:
-                idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2))
-            else:
-                idxs = [pm.Bernoulli(f"idx_{i}", p=(0.75 if i % 2 == 0 else 0.4)) for i in range(6)]
-                idx = pt.stack(idxs, axis=0).reshape((3, 2))
-
-            x = pm.Normal("x", mu=idx.T[:, :, None], shape=(2, 3, 1))
-            y = pm.Normal("y", mu=(idx * 2 - 1), shape=(1, 3, 2))
-
-        return m
-
-    marginal_m = marginalize(build_model(build_batched=True), ["idx"])
-    ref_m = marginalize(build_model(build_batched=False), [f"idx_{i}" for i in range(6)])
-
-    test_point = marginal_m.initial_point()
-
-    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-        np.testing.assert_allclose(
-            marginal_m.compile_logp()(test_point),
-            ref_m.compile_logp()(test_point),
-        )
-
-
-def test_many_to_one_marginalized_rvs():
-    """Test when random variables depend on multiple marginalized variables"""
-    with Model() as m:
-        x = pm.Bernoulli("x", 0.1)
-        y = pm.Bernoulli("y", 0.3)
-        z = pm.DiracDelta("z", c=x + y)
-
-    logp_fn = marginalize(m, [x, y]).compile_logp()
-
-    np.testing.assert_allclose(np.exp(logp_fn({"z": 0})), 0.9 * 0.7)
-    np.testing.assert_allclose(np.exp(logp_fn({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7)
-    np.testing.assert_allclose(np.exp(logp_fn({"z": 2})), 0.1 * 0.3)
-
-
-@pytest.mark.parametrize("batched", (False, "left", "right"))
-def test_nested_marginalized_rvs(batched):
-    """Test that marginalization works when there are nested marginalized RVs"""
-
-    def build_model(build_batched: bool) -> Model:
-        idx_shape = (3,) if build_batched else ()
-        sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5)
-
-        with Model() as m:
-            sigma = pm.HalfNormal("sigma")
-
-            idx = pm.Bernoulli("idx", p=0.75, shape=idx_shape)
-            dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma)
-
-            sub_idx_p = pt.switch(pt.eq(idx, 0), 0.15, 0.95)
-            if build_batched and batched == "right":
-                sub_idx_p = sub_idx_p[..., None]
-                dep = dep[..., None]
-            sub_idx = pm.Bernoulli("sub_idx", p=sub_idx_p, shape=sub_idx_shape)
-            sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma)
-
-        return m
-
-    marginal_m = marginalize(build_model(build_batched=batched), ["idx", "sub_idx"])
-    assert all(rv.name not in ("idx", "sub_idx") for rv in marginal_m.free_RVs)
-
-    # Test forward draws and initial_point, shouldn't depend on batching, so we only test one case
-    if not batched:
-        # Test forward draws
-        dep_draws, sub_dep_draws = draw(
-            [marginal_m["dep"], marginal_m["sub_dep"]],
-            # Make sigma very small to make draws deterministic
-            givens={marginal_m["sigma"]: 0.001},
-            draws=1000,
-            random_seed=214,
-        )
-        assert sorted(np.unique(dep_draws.round()) == [-1000.0, 1000.0])
-        assert sorted(np.unique(sub_dep_draws.round()) == [-1000.0, -900.0, 1000.0, 1100.0])
-
-        # Test initial_point
-        ips = make_initial_point_expression(
-            free_rvs=marginal_m.free_RVs,
-            rvs_to_transforms=marginal_m.rvs_to_transforms,
-            initval_strategies={},
-        )
-        # After simplification, we should have only constants in the graph
-        ip_sigma, ip_dep, ip_sub_dep = constant_fold(ips)
-        np.testing.assert_allclose(ip_sigma, 1.0)
-        np.testing.assert_allclose(ip_dep, 1000.0)
-        np.testing.assert_allclose(ip_sub_dep, np.full((5,), 1100.0))
-
-    # Test logp
-    ref_m = build_model(build_batched=False)
-    ref_logp_fn = ref_m.compile_logp(
-        vars=[ref_m["idx"], ref_m["dep"], ref_m["sub_idx"], ref_m["sub_dep"]]
-    )
-
-    test_point = ref_m.initial_point()
-    test_point["dep"] = np.full_like(test_point["dep"], 1000)
-    test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
-    ref_logp = logsumexp(
-        [
-            ref_logp_fn({**test_point, **{"idx": idx, "sub_idx": np.array(sub_idxs)}})
-            for idx in (0, 1)
-            for sub_idxs in itertools.product((0, 1), repeat=5)
-        ]
-    )
-    if batched:
-        ref_logp *= 3
-
-    test_point = marginal_m.initial_point()
-    test_point["dep"] = np.full_like(test_point["dep"], 1000)
-    test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
-
-    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-        logp = marginal_m.compile_logp(vars=[marginal_m["dep"], marginal_m["sub_dep"]])(test_point)
-
-    np.testing.assert_almost_equal(logp, ref_logp)
-
-
-def test_interdependent_rvs():
-    """Test Marginalization when dependent RVs are interdependent."""
-    with Model() as m:
-        idx = pm.Bernoulli("idx", p=0.75)
-        x = pm.Normal("x", mu=idx * 2, sigma=1e-3)
-        # Y depends on both x and idx
-        y = pm.Normal("y", mu=x * idx * 2, sigma=1e-3)
-
-    marginal_m = marginalize(m, "idx")
-
-    marginal_x = marginal_m["x"]
-    marginal_y = marginal_m["y"]
-    assert isinstance(marginal_x.owner.op, MarginalRV)
-    assert isinstance(marginal_y.owner.op, MarginalRV)
-    assert marginal_x.owner is marginal_y.owner
-
-    # Test forward draws
-    x_draws, y_draws = draw([marginal_x, marginal_y], draws=1000, random_seed=54)
-    assert sorted(np.unique(x_draws.round())) == [0, 2]
-    assert sorted(np.unique(y_draws.round())) == [0, 4]
-    assert np.unique(y_draws[x_draws < 1].round()) == [0]
-    assert np.unique(y_draws[x_draws > 1].round()) == [4]
-
-    # Test initial_point
-    ips = make_initial_point_expression(
-        free_rvs=marginal_m.free_RVs,
-        rvs_to_transforms={},
-        initval_strategies={},
-    )
-    # After simplification, we should have only constants in the graph
-    ip_x, ip_y = constant_fold(ips)
-    np.testing.assert_allclose(ip_x, 2.0)
-    np.testing.assert_allclose(ip_y, 4.0)
-
-    # Test custom initval strategy
-    ips = make_initial_point_expression(
-        # Test that order does not matter
-        free_rvs=marginal_m.free_RVs[::-1],
-        rvs_to_transforms={},
-        initval_strategies={marginal_x: pt.constant(5.0)},
-    )
-    ip_y, ip_x = constant_fold(ips)
-    np.testing.assert_allclose(ip_x, 5.0)
-    np.testing.assert_allclose(ip_y, 10.0)
-
-    # Test logp
-    test_point = marginal_m.initial_point()
-    ref_logp_fn = m.compile_logp([m["idx"], m["x"], m["y"]])
-    ref_logp = logsumexp([ref_logp_fn({**test_point, **{"idx": idx}}) for idx in (0, 1)])
-    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-        logp = marginal_m.compile_logp([marginal_m["x"], marginal_m["y"]])(test_point)
-    np.testing.assert_almost_equal(logp, ref_logp)
-
-
-@pytest.mark.parametrize("advanced_indexing", (False, True))
-def test_marginalized_index_as_key(advanced_indexing):
-    """Test we can marginalize graphs where indexing is used as a mapping."""
-
-    w = [0.1, 0.3, 0.6]
-    mu = pt.as_tensor([-1, 0, 1])
-
-    if advanced_indexing:
-        y_val = pt.as_tensor([[-1, -1], [0, 1]])
-        shape = (2, 2)
-    else:
-        y_val = -1
-        shape = ()
-
-    with Model() as m:
-        x = pm.Categorical("x", p=w, shape=shape)
-        y = pm.Normal("y", mu[x].T, sigma=1, observed=y_val)
-
-    marginal_m = marginalize(m, x)
-
-    marginal_logp = marginal_m.compile_logp(sum=False)({})[0]
-    ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu.T, sigma=1, shape=shape), y_val).eval()
-
-    np.testing.assert_allclose(marginal_logp, ref_logp)
-
-
-def test_marginalized_index_as_value_and_key():
-    """Test we can marginalize graphs were marginalized_rv is indexed."""
-
-    def build_model(build_batched: bool) -> Model:
-        with Model() as m:
-            if build_batched:
-                latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,))
-            else:
-                latent_state = pm.math.stack(
-                    [pm.Bernoulli(f"latent_state_{i}", p=0.3) for i in range(4)]
-                )
-            # latent state is used as the indexed variable
-            latent_intensities = pt.where(latent_state[:, None], [0.0, 1.0, 2.0], [0.0, 10.0, 20.0])
-            picked_intensity = pm.Categorical("picked_intensity", p=[0.2, 0.2, 0.6])
-            # picked intensity is used as the indexing variable
-            pm.Normal(
-                "intensity",
-                mu=latent_intensities[:, picked_intensity],
-                observed=[0.5, 1.5, 5.0, 15.0],
-            )
-        return m
-
-    # We compare with the equivalent but less efficient batched model
-    m = build_model(build_batched=True)
-    ref_m = build_model(build_batched=False)
-
-    m = marginalize(m, ["latent_state"])
-    ref_m = marginalize(ref_m, [f"latent_state_{i}" for i in range(4)])
-    test_point = {"picked_intensity": 1}
-    np.testing.assert_allclose(
-        m.compile_logp()(test_point),
-        ref_m.compile_logp()(test_point),
-    )
-
-    m = marginalize(m, ["picked_intensity"])
-    ref_m = marginalize(ref_m, ["picked_intensity"])
-    test_point = {}
-    np.testing.assert_allclose(
-        m.compile_logp()(test_point),
-        ref_m.compile_logp()(test_point),
-    )
-
-
-class TestNotSupportedMixedDims:
-    """Test lack of support for models where batch dims of marginalized variables are mixed."""
-
-    def test_mixed_dims_via_transposed_dot(self):
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.7, shape=2)
-            y = pm.Normal("y", mu=idx @ idx.T)
-
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-    def test_mixed_dims_via_indexing(self):
-        mean = pt.as_tensor([[0.1, 0.9], [0.6, 0.4]])
-
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.7, shape=2)
-            y = pm.Normal("y", mu=mean[idx, :] + mean[:, idx])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.7, shape=2)
-            y = pm.Normal("y", mu=mean[idx, None] + mean[None, idx])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.7, shape=2)
-            mu = pt.specify_broadcastable(mean[:, None][idx], 1) + pt.specify_broadcastable(
-                mean[None, :][:, idx], 0
-            )
-            y = pm.Normal("y", mu=mu)
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.7, shape=2)
-            y = pm.Normal("y", mu=idx[0] + idx[1])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-    def test_mixed_dims_via_vector_indexing(self):
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.7, shape=2)
-            y = pm.Normal("y", mu=idx[[0, 1, 0, 0]])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-        with Model() as m:
-            idx = pm.Categorical("key", p=[0.1, 0.3, 0.6], shape=(2, 2))
-            y = pm.Normal("y", pt.as_tensor([[0, 1], [2, 3]])[idx.astype(bool)])
-        with pytest.raises(NotImplementedError):
-            marginalize(m, idx)
-
-    def test_mixed_dims_via_support_dimension(self):
-        with Model() as m:
-            x = pm.Bernoulli("x", p=0.7, shape=3)
-            y = pm.Dirichlet("y", a=x * 10 + 1)
-        with pytest.raises(NotImplementedError):
-            marginalize(m, x)
-
-    def test_mixed_dims_via_nested_marginalization(self):
-        with Model() as m:
-            x = pm.Bernoulli("x", p=0.7, shape=(3,))
-            y = pm.Bernoulli("y", p=0.7, shape=(2,))
-            z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2))
-
-        with pytest.raises(NotImplementedError):
-            marginalize(m, [x, y])
-
-
-def test_marginalized_deterministic_and_potential():
-    rng = np.random.default_rng(299)
-
-    with Model() as m:
-        x = pm.Bernoulli("x", p=0.7)
-        y = pm.Normal("y", x)
-        z = pm.Normal("z", x)
-        det = pm.Deterministic("det", y + z)
-        pot = pm.Potential("pot", y + z + 1)
-
-    marginal_m = marginalize(m, [x])
-
-    y_draw, z_draw, det_draw, pot_draw = pm.draw([y, z, det, pot], draws=5, random_seed=rng)
-    np.testing.assert_almost_equal(y_draw + z_draw, det_draw)
-    np.testing.assert_almost_equal(det_draw, pot_draw - 1)
-
-    y_value = marginal_m.rvs_to_values[marginal_m["y"]]
-    z_value = marginal_m.rvs_to_values[marginal_m["z"]]
-    det_value, pot_value = marginal_m.replace_rvs_by_values([marginal_m["det"], marginal_m["pot"]])
-    assert set(inputvars([det_value, pot_value])) == {y_value, z_value}
-    assert det_value.eval({y_value: 2, z_value: 5}) == 7
-    assert pot_value.eval({y_value: 2, z_value: 5}) == 8
-
-
-def test_not_supported_marginalized_deterministic_and_potential():
-    with Model() as m:
-        x = pm.Bernoulli("x", p=0.7)
-        y = pm.Normal("y", x)
-        det = pm.Deterministic("det", x + y)
-
-    with pytest.raises(
-        NotImplementedError, match="Cannot marginalize x due to dependent Deterministic det"
-    ):
-        marginalize(m, [x])
-
-    with Model() as m:
-        x = pm.Bernoulli("x", p=0.7)
-        y = pm.Normal("y", x)
-        pot = pm.Potential("pot", x + y)
-
-    with pytest.raises(
-        NotImplementedError, match="Cannot marginalize x due to dependent Potential pot"
-    ):
-        marginalize(m, [x])
-
-
-@pytest.mark.parametrize(
-    "transform, expected_warning",
-    (
-        (None, does_not_warn()),
-        (UNSET, does_not_warn()),
-        (transforms.log, does_not_warn()),
-        (transforms.Chain([transforms.logodds, transforms.log]), does_not_warn()),
-        (
-            transforms.Interval(0, 2),
-            pytest.warns(
-                UserWarning, match="which depends on the marginalized idx may no longer work"
-            ),
-        ),
-        (
-            transforms.Chain([transforms.log, transforms.Interval(-1, 1)]),
-            pytest.warns(
-                UserWarning, match="which depends on the marginalized idx may no longer work"
-            ),
-        ),
-    ),
-)
-def test_marginalized_transforms(transform, expected_warning):
-    w = [0.1, 0.3, 0.6]
-    data = [0, 5, 10]
-    initval = 0.7  # Value that will be negative on the unconstrained space
-
-    with pm.Model() as m_ref:
-        sigma = pm.Mixture(
-            "sigma",
-            w=w,
-            comp_dists=pm.HalfNormal.dist([1, 2, 3]),
-            initval=initval,
-            default_transform=transform,
-        )
-        y = pm.Normal("y", 0, sigma, observed=data)
-
-    with Model() as m:
-        idx = pm.Categorical("idx", p=w)
-        sigma = pm.HalfNormal(
-            "sigma",
-            pt.switch(
-                pt.eq(idx, 0),
-                1,
-                pt.switch(
-                    pt.eq(idx, 1),
-                    2,
-                    3,
-                ),
-            ),
-            default_transform=transform,
-        )
-        y = pm.Normal("y", 0, sigma, observed=data)
-
-    with expected_warning:
-        marginal_m = marginalize(m, [idx])
-
-    marginal_m.set_initval(marginal_m["sigma"], initval)
-    ip = marginal_m.initial_point()
-    if transform is not None:
-        if transform is UNSET:
-            transform_name = "log"
-        else:
-            transform_name = transform.name
-        assert -np.inf < ip[f"sigma_{transform_name}__"] < 0.0
-    np.testing.assert_allclose(marginal_m.compile_logp()(ip), m_ref.compile_logp()(ip))
-
-
-def test_data_container():
-    """Test that MarginalModel can handle Data containers."""
-    with Model(coords={"obs": [0]}) as m:
-        x = pm.Data("x", 2.5)
-        idx = pm.Bernoulli("idx", p=0.7, dims="obs")
-        y = pm.Normal("y", idx * x, dims="obs")
-
-    marginal_m = marginalize(m, [idx])
-
-    logp_fn = marginal_m.compile_logp()
-
-    with pm.Model(coords={"obs": [0]}) as m_ref:
-        x = pm.Data("x", 2.5)
-        y = pm.NormalMixture("y", w=[0.3, 0.7], mu=[0, x], dims="obs")
-
-    ref_logp_fn = m_ref.compile_logp()
-
-    for i, x_val in enumerate((-1.5, 2.5, 3.5), start=1):
-        for m in (marginal_m, m_ref):
-            m.set_dim("obs", new_length=i, coord_values=tuple(range(i)))
-            pm.set_data({"x": x_val}, model=m)
-
-        ip = marginal_m.initial_point()
-        np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ip))
-
-
-def test_mutable_indexing_jax_backend():
-    pytest.importorskip("jax")
-    from pymc.sampling.jax import get_jaxified_logp
-
-    with Model() as model:
-        data = pm.Data("data", np.zeros(10))
-
-        cat_effect = pm.Normal("cat_effect", sigma=1, shape=5)
-        cat_effect_idx = pm.Data("cat_effect_idx", np.array([0, 1] * 5))
-
-        is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10)
-        pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
-    marginal_model = marginalize(model, ["is_outlier"])
-    get_jaxified_logp(marginal_model)
-
-
-class TestFullModels:
-    @pytest.fixture
-    def disaster_model(self):
-        # fmt: off
-        disaster_data = pd.Series(
-            [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
-             3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
-             2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,
-             1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
-             0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
-             3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
-             0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]
-        )
-        # fmt: on
-        years = np.arange(1851, 1962)
-
-        with Model() as disaster_model:
-            switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max())
-            early_rate = pm.Exponential("early_rate", 1.0)
-            late_rate = pm.Exponential("late_rate", 1.0)
-            rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
-            with pytest.warns(Warning):
-                disasters = pm.Poisson("disasters", rate, observed=disaster_data)
-
-        return disaster_model, years
-
-    def test_change_point_model(self, disaster_model):
-        m, years = disaster_model
-
-        ip = m.initial_point()
-        ip["late_rate_log__"] += 1.0  # Make early and endpoint ip different
-
-        ip.pop("switchpoint")
-        ref_logp_fn = m.compile_logp(
-            [m["switchpoint"], m["disasters_observed"], m["disasters_unobserved"]]
-        )
-        ref_logp = logsumexp([ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years])
-
-        marginal_m = marginalize(m, m["switchpoint"])
-
-        with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-            marginal_m_logp = marginal_m.compile_logp(
-                [marginal_m["disasters_observed"], marginal_m["disasters_unobserved"]]
-            )(ip)
-        np.testing.assert_almost_equal(marginal_m_logp, ref_logp)
-
-    @pytest.mark.slow
-    def test_change_point_model_sampling(self, disaster_model):
-        m, _ = disaster_model
-
-        rng = np.random.default_rng(211)
-
-        with m:
-            before_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(
-                sample=("draw", "chain")
-            )
-
-        marginal_m = marginalize(m, "switchpoint")
-
-        with marginal_m:
-            with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-                after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(
-                    sample=("draw", "chain")
-                )
-
-        np.testing.assert_allclose(
-            before_marg["early_rate"].mean(), after_marg["early_rate"].mean(), rtol=1e-2
-        )
-        np.testing.assert_allclose(
-            before_marg["late_rate"].mean(), after_marg["late_rate"].mean(), rtol=1e-2
-        )
-        np.testing.assert_allclose(
-            before_marg["disasters_unobserved"].mean(),
-            after_marg["disasters_unobserved"].mean(),
-            rtol=1e-2,
-        )
-
-    @pytest.mark.parametrize("univariate", (True, False))
-    def test_vector_univariate_mixture(self, univariate):
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
-
-            def dist(idx, size):
-                return pm.math.switch(
-                    pm.math.eq(idx, 0),
-                    pm.Normal.dist([-10, -10], 1),
-                    pm.Normal.dist([10, 10], 1),
-                )
-
-            pm.CustomDist("norm", idx, dist=dist)
-
-        marginal_m = marginalize(m, idx)
-        logp_fn = marginal_m.compile_logp()
-
-        if univariate:
-            with pm.Model() as ref_m:
-                pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
-        else:
-            with pm.Model() as ref_m:
-                pm.Mixture(
-                    "norm",
-                    w=[0.5, 0.5],
-                    comp_dists=[
-                        pm.MvNormal.dist([-10, -10], np.eye(2)),
-                        pm.MvNormal.dist([10, 10], np.eye(2)),
-                    ],
-                    shape=(2,),
-                )
-        ref_logp_fn = ref_m.compile_logp()
-
-        for test_value in (
-            [-10, -10],
-            [10, 10],
-            [-10, 10],
-            [-10, 10],
-        ):
-            pt = {"norm": test_value}
-            np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
-
-    def test_k_censored_clusters_model(self):
-        data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
-        nobs = data.shape[0]
-        n_clusters = 5
-
-        def build_model(build_batched: bool) -> Model:
-            coords = {
-                "cluster": range(n_clusters),
-                "ndim": ("x", "y"),
-                "obs": range(nobs),
-            }
-            with Model(coords=coords) as m:
-                if build_batched:
-                    idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"])
-                else:
-                    idx = pm.math.stack(
-                        [
-                            pm.Categorical(f"idx_{i}", p=np.ones(n_clusters) / n_clusters)
-                            for i in range(nobs)
-                        ]
-                    )
-
-                mu_x = pm.Normal(
-                    "mu_x",
-                    dims=["cluster"],
-                    transform=ordered,
-                )
-                mu_y = pm.Normal("mu_y", dims=["cluster"])
-                mu = pm.math.stack([mu_x, mu_y], axis=-1)  # (cluster, ndim)
-                mu_indexed = mu[idx, :]
-
-                sigma = pm.HalfNormal("sigma")
-
-                y = pm.Censored(
-                    "y",
-                    dist=pm.Normal.dist(mu_indexed, sigma),
-                    lower=-3,
-                    upper=3,
-                    observed=data,
-                    dims=["obs", "ndim"],
-                )
-
-            return m
-
-        m = marginalize(build_model(build_batched=True), "idx")
-        m.set_initval(m["mu_x"], np.linspace(-1, 1, n_clusters))
-
-        ref_m = marginalize(build_model(build_batched=False), [f"idx_{i}" for i in range(nobs)])
-        test_point = m.initial_point()
-        np.testing.assert_almost_equal(
-            m.compile_logp()(test_point),
-            ref_m.compile_logp()(test_point),
-        )
-
-
-def test_unmarginalize():
-    with pm.Model() as m:
-        idx = pm.Bernoulli("idx", p=0.5)
-        sub_idx = pm.Bernoulli("sub_idx", p=pt.as_tensor([0.3, 0.7])[idx])
-        x = pm.Normal("x", mu=(idx + sub_idx) - 1)
-
-    marginal_m = marginalize(m, [idx, sub_idx])
-    assert not equivalent_models(marginal_m, m)
-
-    unmarginal_m = unmarginalize(marginal_m)
-    assert equivalent_models(unmarginal_m, m)
-
-    unmarginal_idx_explicit = unmarginalize(marginal_m, ("idx", "sub_idx"))
-    assert equivalent_models(unmarginal_idx_explicit, m)
-
-    # Test partial unmarginalize
-    unmarginal_idx = unmarginalize(marginal_m, "idx")
-    assert equivalent_models(unmarginal_idx, marginalize(m, "sub_idx"))
-
-    unmarginal_sub_idx = unmarginalize(marginal_m, "sub_idx")
-    assert equivalent_models(unmarginal_sub_idx, marginalize(m, "idx"))
-
-
-class TestRecoverMarginals:
-    def test_basic(self):
-        with Model() as m:
-            sigma = pm.HalfNormal("sigma")
-            p = np.array([0.5, 0.2, 0.3])
-            k = pm.Categorical("k", p=p)
-            mu = np.array([-3.0, 0.0, 3.0])
-            mu_ = pt.as_tensor_variable(mu)
-            y = pm.Normal("y", mu=mu_[k], sigma=sigma)
-
-        marginal_m = marginalize(m, [k])
-
-        rng = np.random.default_rng(211)
-
-        with marginal_m:
-            prior = pm.sample_prior_predictive(
-                draws=20,
-                random_seed=rng,
-                return_inferencedata=False,
-            )
-            idata = InferenceData(posterior=dict_to_dataset(prior))
-
-        idata = recover_marginals(marginal_m, idata, return_samples=True)
-        post = idata.posterior
-        assert "k" in post
-        assert "lp_k" in post
-        assert post.k.shape == post.y.shape
-        assert post.lp_k.shape == (*post.k.shape, len(p))
-
-        def true_logp(y, sigma):
-            y = y.repeat(len(p)).reshape(len(y), -1)
-            sigma = sigma.repeat(len(p)).reshape(len(sigma), -1)
-            return log_softmax(
-                np.log(p)
-                + norm.logpdf(y, loc=mu, scale=sigma)
-                + halfnorm.logpdf(sigma)
-                + np.log(sigma),
-                axis=1,
-            )
-
-        np.testing.assert_almost_equal(
-            true_logp(post.y.values.flatten(), post.sigma.values.flatten()),
-            post.lp_k[0].values,
-        )
-        np.testing.assert_almost_equal(logsumexp(post.lp_k, axis=-1), 0)
-
-    def test_coords(self):
-        """Test if coords can be recovered with marginalized value had it originally"""
-        with Model(coords={"year": [1990, 1991, 1992]}) as m:
-            sigma = pm.HalfNormal("sigma")
-            idx = pm.Bernoulli("idx", p=0.75, dims="year")
-            x = pm.Normal("x", mu=idx, sigma=sigma, dims="year")
-
-        marginal_m = marginalize(m, [idx])
-        rng = np.random.default_rng(211)
-
-        with marginal_m:
-            prior = pm.sample_prior_predictive(
-                draws=20,
-                random_seed=rng,
-                return_inferencedata=False,
-            )
-            idata = InferenceData(
-                posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
-            )
-
-        idata = recover_marginals(marginal_m, idata, return_samples=True)
-        post = idata.posterior
-        assert post.idx.dims == ("chain", "draw", "year")
-        assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim")
-
-    def test_batched(self):
-        """Test that marginalization works for batched random variables"""
-        with Model() as m:
-            sigma = pm.HalfNormal("sigma")
-            idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
-            y = pm.Normal("y", mu=idx.T, sigma=sigma, shape=(2, 3))
-
-        marginal_m = marginalize(m, [idx])
-
-        rng = np.random.default_rng(211)
-
-        with marginal_m:
-            prior = pm.sample_prior_predictive(
-                draws=20,
-                random_seed=rng,
-                return_inferencedata=False,
-            )
-            idata = InferenceData(
-                posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
-            )
-
-        idata = recover_marginals(marginal_m, idata, return_samples=True)
-        post = idata.posterior
-        assert post["y"].shape == (1, 20, 2, 3)
-        assert post["idx"].shape == (1, 20, 3, 2)
-        assert post["lp_idx"].shape == (1, 20, 3, 2, 2)
-
-    def test_nested(self):
-        """Test that marginalization works when there are nested marginalized RVs"""
-
-        with Model() as m:
-            idx = pm.Bernoulli("idx", p=0.75)
-            sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95))
-            sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0)
-
-        marginal_m = marginalize(m, [idx, sub_idx])
-
-        rng = np.random.default_rng(211)
-
-        with marginal_m:
-            prior = pm.sample_prior_predictive(
-                draws=20,
-                random_seed=rng,
-                return_inferencedata=False,
-            )
-            idata = InferenceData(posterior=dict_to_dataset(prior))
-
-        idata = recover_marginals(marginal_m, idata, return_samples=True)
-        post = idata.posterior
-        assert "idx" in post
-        assert "lp_idx" in post
-        assert post.idx.shape == post.y.shape
-        assert post.lp_idx.shape == (*post.idx.shape, 2)
-        assert "sub_idx" in post
-        assert "lp_sub_idx" in post
-        assert post.sub_idx.shape == post.y.shape
-        assert post.lp_sub_idx.shape == (*post.sub_idx.shape, 2)
-
-        def true_idx_logp(y):
-            idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1))
-            idx_1 = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
-            return log_softmax(np.stack([idx_0, idx_1]).T, axis=1)
-
-        np.testing.assert_almost_equal(
-            true_idx_logp(post.y.values.flatten()),
-            post.lp_idx[0].values,
-        )
-
-        def true_sub_idx_logp(y):
-            sub_idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1))
-            sub_idx_1 = np.log(0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
-            return log_softmax(np.stack([sub_idx_0, sub_idx_1]).T, axis=1)
-
-        np.testing.assert_almost_equal(
-            true_sub_idx_logp(post.y.values.flatten()),
-            post.lp_sub_idx[0].values,
-        )
-        np.testing.assert_almost_equal(logsumexp(post.lp_idx, axis=-1), 0)
-        np.testing.assert_almost_equal(logsumexp(post.lp_sub_idx, axis=-1), 0)
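The bulk of this hunk removes tests/model/marginal/test_marginal_model.py, whose tests exercised the functional marginalization API (marginalize, unmarginalize, recover_marginals). For orientation, a minimal sketch of the workflow those tests covered, condensed from the deleted code; the toy model itself is illustrative, not part of the package:

import pymc as pm

from pymc_extras.model.marginal.marginal_model import marginalize, unmarginalize

# A model with a discrete latent variable that can be summed out analytically.
with pm.Model() as m:
    idx = pm.Bernoulli("idx", p=0.75)
    x = pm.Normal("x", mu=idx * 2.0, sigma=1.0)

# marginalize() returns a new model in which "idx" is no longer a free
# variable; its logp sums over both values of idx.
marginal_m = marginalize(m, [idx])
logp_fn = marginal_m.compile_logp()
print(logp_fn(marginal_m.initial_point()))

# unmarginalize() recovers a model equivalent to the original.
unmarginal_m = unmarginalize(marginal_m)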