pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
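The largest addition below is tests/model/marginal/test_marginal_model.py, which exercises the MarginalModel API shipped in pymc_extras/model/marginal/. As a minimal orientation sketch (not part of the diff itself; names and calls are taken from the tests below, assuming pymc and pymc-extras 0.2.0 are installed), the workflow the tests exercise is:

    import numpy as np
    import pymc as pm

    from pymc_extras.model.marginal.marginal_model import MarginalModel

    with MarginalModel() as m:
        sigma = pm.HalfNormal("sigma")
        idx = pm.Bernoulli("idx", p=0.75)  # discrete RV to be summed out
        y = pm.Normal("y", mu=idx * 2 - 1, sigma=sigma, observed=np.zeros(5))

    m.marginalize([idx])  # "idx" leaves m.free_RVs and enters m.marginalized_rvs
    print(m.compile_logp()(m.initial_point()))  # logp with idx summed out

The tests compare such marginalized logps against hand-built pm.NormalMixture references and, via recover_marginals, check that the marginalized discrete values can be sampled back after inference.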
@@ -0,0 +1,867 @@
+ import itertools
+
+ from contextlib import suppress as does_not_warn
+
+ import numpy as np
+ import pandas as pd
+ import pymc as pm
+ import pytensor.tensor as pt
+ import pytest
+
+ from arviz import InferenceData, dict_to_dataset
+ from pymc.distributions import transforms
+ from pymc.distributions.transforms import ordered
+ from pymc.model.fgraph import fgraph_from_model
+ from pymc.pytensorf import inputvars
+ from pymc.util import UNSET
+ from scipy.special import log_softmax, logsumexp
+ from scipy.stats import halfnorm, norm
+
+ from pymc_extras.model.marginal.marginal_model import (
+     MarginalModel,
+     marginalize,
+ )
+ from tests.utils import equal_computations_up_to_root
+
+
+ def test_basic_marginalized_rv():
+     data = [2] * 5
+
+     with MarginalModel() as m:
+         sigma = pm.HalfNormal("sigma")
+         idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6])
+         mu = pt.switch(
+             pt.eq(idx, 0),
+             -1.0,
+             pt.switch(
+                 pt.eq(idx, 1),
+                 0.0,
+                 1.0,
+             ),
+         )
+         y = pm.Normal("y", mu=mu, sigma=sigma)
+         z = pm.Normal("z", y, observed=data)
+
+     m.marginalize([idx])
+     assert idx not in m.free_RVs
+     assert [rv.name for rv in m.marginalized_rvs] == ["idx"]
+
+     # Test logp
+     with pm.Model() as m_ref:
+         sigma = pm.HalfNormal("sigma")
+         y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma)
+         z = pm.Normal("z", y, observed=data)
+
+     test_point = m_ref.initial_point()
+     ref_logp = m_ref.compile_logp()(test_point)
+     ref_dlogp = m_ref.compile_dlogp([m_ref["y"]])(test_point)
+
+     # Assert we can marginalize and unmarginalize internally non-destructively
+     for i in range(3):
+         np.testing.assert_almost_equal(
+             m.compile_logp()(test_point),
+             ref_logp,
+         )
+         np.testing.assert_almost_equal(
+             m.compile_dlogp([m["y"]])(test_point),
+             ref_dlogp,
+         )
+
+
+ def test_one_to_one_marginalized_rvs():
+     """Test case with multiple, independent marginalized RVs."""
+     with MarginalModel() as m:
+         sigma = pm.HalfNormal("sigma")
+         idx1 = pm.Bernoulli("idx1", p=0.75)
+         x = pm.Normal("x", mu=idx1, sigma=sigma)
+         idx2 = pm.Bernoulli("idx2", p=0.75, shape=(5,))
+         y = pm.Normal("y", mu=(idx2 * 2 - 1), sigma=sigma, shape=(5,))
+
+     m.marginalize([idx1, idx2])
+     assert m["x"].owner is not m["y"].owner
+     _m = m.clone()._marginalize()
+     assert _m["x"].owner is not _m["y"].owner
+
+     with pm.Model() as m_ref:
+         sigma = pm.HalfNormal("sigma")
+         x = pm.NormalMixture("x", w=[0.25, 0.75], mu=[0, 1], sigma=sigma)
+         y = pm.NormalMixture("y", w=[0.25, 0.75], mu=[-1, 1], sigma=sigma, shape=(5,))
+
+     # Test logp
+     test_point = m_ref.initial_point()
+     x_logp, y_logp = m.compile_logp(vars=[m["x"], m["y"]], sum=False)(test_point)
+     x_ref_logp, y_ref_logp = m_ref.compile_logp(vars=[m_ref["x"], m_ref["y"]], sum=False)(test_point)
+     np.testing.assert_array_almost_equal(x_logp, x_ref_logp.sum())
+     np.testing.assert_array_almost_equal(y_logp, y_ref_logp)
+
+
+ def test_one_to_many_marginalized_rvs():
+     """Test that marginalization works when there is more than one dependent RV"""
+     with MarginalModel() as m:
+         sigma = pm.HalfNormal("sigma")
+         idx = pm.Bernoulli("idx", p=0.75)
+         x = pm.Normal("x", mu=idx, sigma=sigma)
+         y = pm.Normal("y", mu=(idx * 2 - 1), sigma=sigma, shape=(5,))
+
+     ref_logp_x_y_fn = m.compile_logp([idx, x, y])
+
+     with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+         m.marginalize([idx])
+
+     assert m["x"].owner is not m["y"].owner
+     _m = m.clone()._marginalize()
+     assert _m["x"].owner is _m["y"].owner
+
+     tp = m.initial_point()
+     ref_logp_x_y = logsumexp([ref_logp_x_y_fn({**tp, **{"idx": idx}}) for idx in (0, 1)])
+     logp_x_y = m.compile_logp([x, y])(tp)
+     np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y)
+
+
+ def test_one_to_many_unaligned_marginalized_rvs():
+     """Test that marginalization works when there is more than one dependent RV with batch dimensions that are not aligned"""
+
+     def build_model(build_batched: bool):
+         with MarginalModel() as m:
+             if build_batched:
+                 idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2))
+             else:
+                 idxs = [pm.Bernoulli(f"idx_{i}", p=(0.75 if i % 2 == 0 else 0.4)) for i in range(6)]
+                 idx = pt.stack(idxs, axis=0).reshape((3, 2))
+
+             x = pm.Normal("x", mu=idx.T[:, :, None], shape=(2, 3, 1))
+             y = pm.Normal("y", mu=(idx * 2 - 1), shape=(1, 3, 2))
+
+         return m
+
+     m = build_model(build_batched=True)
+     ref_m = build_model(build_batched=False)
+
+     with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+         m.marginalize(["idx"])
+         ref_m.marginalize([f"idx_{i}" for i in range(6)])
+
+     test_point = m.initial_point()
+     np.testing.assert_allclose(
+         m.compile_logp()(test_point),
+         ref_m.compile_logp()(test_point),
+     )
+
+
+ def test_many_to_one_marginalized_rvs():
+     """Test when random variables depend on multiple marginalized variables"""
+     with MarginalModel() as m:
+         x = pm.Bernoulli("x", 0.1)
+         y = pm.Bernoulli("y", 0.3)
+         z = pm.DiracDelta("z", c=x + y)
+
+     m.marginalize([x, y])
+     logp = m.compile_logp()
+
+     np.testing.assert_allclose(np.exp(logp({"z": 0})), 0.9 * 0.7)
+     np.testing.assert_allclose(np.exp(logp({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7)
+     np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3)
+
+
+ @pytest.mark.parametrize("batched", (False, "left", "right"))
+ def test_nested_marginalized_rvs(batched):
+     """Test that marginalization works when there are nested marginalized RVs"""
+
+     def build_model(build_batched: bool) -> MarginalModel:
+         idx_shape = (3,) if build_batched else ()
+         sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5)
+
+         with MarginalModel() as m:
+             sigma = pm.HalfNormal("sigma")
+
+             idx = pm.Bernoulli("idx", p=0.75, shape=idx_shape)
+             dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma)
+
+             sub_idx_p = pt.switch(pt.eq(idx, 0), 0.15, 0.95)
+             if build_batched and batched == "right":
+                 sub_idx_p = sub_idx_p[..., None]
+                 dep = dep[..., None]
+             sub_idx = pm.Bernoulli("sub_idx", p=sub_idx_p, shape=sub_idx_shape)
+             sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma)
+
+         return m
+
+     m = build_model(build_batched=batched)
+     with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+         m.marginalize(["idx", "sub_idx"])
+     assert sorted(rv.name for rv in m.marginalized_rvs) == ["idx", "sub_idx"]
+
+     # Test logp
+     ref_m = build_model(build_batched=False)
+     ref_logp_fn = ref_m.compile_logp(
+         vars=[ref_m["idx"], ref_m["dep"], ref_m["sub_idx"], ref_m["sub_dep"]]
+     )
+
+     test_point = ref_m.initial_point()
+     test_point["dep"] = np.full_like(test_point["dep"], 1000)
+     test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
+     ref_logp = logsumexp(
+         [
+             ref_logp_fn({**test_point, **{"idx": idx, "sub_idx": np.array(sub_idxs)}})
+             for idx in (0, 1)
+             for sub_idxs in itertools.product((0, 1), repeat=5)
+         ]
+     )
+     if batched:
+         ref_logp *= 3
+
+     test_point = m.initial_point()
+     test_point["dep"] = np.full_like(test_point["dep"], 1000)
+     test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
+     logp = m.compile_logp(vars=[m["dep"], m["sub_dep"]])(test_point)
+
+     np.testing.assert_almost_equal(logp, ref_logp)
+
+
+ @pytest.mark.parametrize("advanced_indexing", (False, True))
+ def test_marginalized_index_as_key(advanced_indexing):
+     """Test we can marginalize graphs where indexing is used as a mapping."""
+
+     w = [0.1, 0.3, 0.6]
+     mu = pt.as_tensor([-1, 0, 1])
+
+     if advanced_indexing:
+         y_val = pt.as_tensor([[-1, -1], [0, 1]])
+         shape = (2, 2)
+     else:
+         y_val = -1
+         shape = ()
+
+     with MarginalModel() as m:
+         x = pm.Categorical("x", p=w, shape=shape)
+         y = pm.Normal("y", mu[x].T, sigma=1, observed=y_val)
+
+     m.marginalize(x)
+
+     marginal_logp = m.compile_logp(sum=False)({})[0]
+     ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu.T, sigma=1, shape=shape), y_val).eval()
+
+     np.testing.assert_allclose(marginal_logp, ref_logp)
+
+
+ def test_marginalized_index_as_value_and_key():
+     """Test we can marginalize graphs where the marginalized RV is indexed."""
+
+     def build_model(build_batched: bool) -> MarginalModel:
+         with MarginalModel() as m:
+             if build_batched:
+                 latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,))
+             else:
+                 latent_state = pm.math.stack(
+                     [pm.Bernoulli(f"latent_state_{i}", p=0.3) for i in range(4)]
+                 )
+             # latent state is used as the indexed variable
+             latent_intensities = pt.where(latent_state[:, None], [0.0, 1.0, 2.0], [0.0, 10.0, 20.0])
+             picked_intensity = pm.Categorical("picked_intensity", p=[0.2, 0.2, 0.6])
+             # picked intensity is used as the indexing variable
+             pm.Normal(
+                 "intensity",
+                 mu=latent_intensities[:, picked_intensity],
+                 observed=[0.5, 1.5, 5.0, 15.0],
+             )
+         return m
+
+     # We compare with the equivalent but less efficient unbatched model
+     m = build_model(build_batched=True)
+     ref_m = build_model(build_batched=False)
+
+     m.marginalize(["latent_state"])
+     ref_m.marginalize([f"latent_state_{i}" for i in range(4)])
+     test_point = {"picked_intensity": 1}
+     np.testing.assert_allclose(
+         m.compile_logp()(test_point),
+         ref_m.compile_logp()(test_point),
+     )
+
+     m.marginalize(["picked_intensity"])
+     ref_m.marginalize(["picked_intensity"])
+     test_point = {}
+     np.testing.assert_allclose(
+         m.compile_logp()(test_point),
+         ref_m.compile_logp()(test_point),
+     )
+
+
+ class TestNotSupportedMixedDims:
+     """Test lack of support for models where batch dims of marginalized variables are mixed."""
+
+     def test_mixed_dims_via_transposed_dot(self):
+         with MarginalModel() as m:
+             idx = pm.Bernoulli("idx", p=0.7, shape=2)
+             y = pm.Normal("y", mu=idx @ idx.T)
+         with pytest.raises(NotImplementedError):
+             m.marginalize(idx)
+
+     def test_mixed_dims_via_indexing(self):
+         mean = pt.as_tensor([[0.1, 0.9], [0.6, 0.4]])
+
+         with MarginalModel() as m:
+             idx = pm.Bernoulli("idx", p=0.7, shape=2)
+             y = pm.Normal("y", mu=mean[idx, :] + mean[:, idx])
+         with pytest.raises(NotImplementedError):
+             m.marginalize(idx)
+
+         with MarginalModel() as m:
+             idx = pm.Bernoulli("idx", p=0.7, shape=2)
+             y = pm.Normal("y", mu=mean[idx, None] + mean[None, idx])
+         with pytest.raises(NotImplementedError):
+             m.marginalize(idx)
+
+         with MarginalModel() as m:
+             idx = pm.Bernoulli("idx", p=0.7, shape=2)
+             mu = pt.specify_broadcastable(mean[:, None][idx], 1) + pt.specify_broadcastable(
+                 mean[None, :][:, idx], 0
+             )
+             y = pm.Normal("y", mu=mu)
+         with pytest.raises(NotImplementedError):
+             m.marginalize(idx)
+
+         with MarginalModel() as m:
+             idx = pm.Bernoulli("idx", p=0.7, shape=2)
+             y = pm.Normal("y", mu=idx[0] + idx[1])
+         with pytest.raises(NotImplementedError):
+             m.marginalize(idx)
+
+     def test_mixed_dims_via_vector_indexing(self):
+         with MarginalModel() as m:
+             idx = pm.Bernoulli("idx", p=0.7, shape=2)
+             y = pm.Normal("y", mu=idx[[0, 1, 0, 0]])
+         with pytest.raises(NotImplementedError):
+             m.marginalize(idx)
+
+         with MarginalModel() as m:
+             idx = pm.Categorical("key", p=[0.1, 0.3, 0.6], shape=(2, 2))
+             y = pm.Normal("y", pt.as_tensor([[0, 1], [2, 3]])[idx.astype(bool)])
+         with pytest.raises(NotImplementedError):
+             m.marginalize(idx)
+
+     def test_mixed_dims_via_support_dimension(self):
+         with MarginalModel() as m:
+             x = pm.Bernoulli("x", p=0.7, shape=3)
+             y = pm.Dirichlet("y", a=x * 10 + 1)
+         with pytest.raises(NotImplementedError):
+             m.marginalize(x)
+
+     def test_mixed_dims_via_nested_marginalization(self):
+         with MarginalModel() as m:
+             x = pm.Bernoulli("x", p=0.7, shape=(3,))
+             y = pm.Bernoulli("y", p=0.7, shape=(2,))
+             z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2))
+
+         with pytest.raises(NotImplementedError):
+             m.marginalize([x, y])
+
+
+ def test_marginalized_deterministic_and_potential():
+     rng = np.random.default_rng(299)
+
+     with MarginalModel() as m:
+         x = pm.Bernoulli("x", p=0.7)
+         y = pm.Normal("y", x)
+         z = pm.Normal("z", x)
+         det = pm.Deterministic("det", y + z)
+         pot = pm.Potential("pot", y + z + 1)
+
+     with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+         m.marginalize([x])
+
+     y_draw, z_draw, det_draw, pot_draw = pm.draw([y, z, det, pot], draws=5, random_seed=rng)
+     np.testing.assert_almost_equal(y_draw + z_draw, det_draw)
+     np.testing.assert_almost_equal(det_draw, pot_draw - 1)
+
+     y_value = m.rvs_to_values[y]
+     z_value = m.rvs_to_values[z]
+     det_value, pot_value = m.replace_rvs_by_values([det, pot])
+     assert set(inputvars([det_value, pot_value])) == {y_value, z_value}
+     assert det_value.eval({y_value: 2, z_value: 5}) == 7
+     assert pot_value.eval({y_value: 2, z_value: 5}) == 8
+
+
+ def test_not_supported_marginalized_deterministic_and_potential():
+     with MarginalModel() as m:
+         x = pm.Bernoulli("x", p=0.7)
+         y = pm.Normal("y", x)
+         det = pm.Deterministic("det", x + y)
+
+     with pytest.raises(
+         NotImplementedError, match="Cannot marginalize x due to dependent Deterministic det"
+     ):
+         m.marginalize([x])
+
+     with MarginalModel() as m:
+         x = pm.Bernoulli("x", p=0.7)
+         y = pm.Normal("y", x)
+         pot = pm.Potential("pot", x + y)
+
+     with pytest.raises(
+         NotImplementedError, match="Cannot marginalize x due to dependent Potential pot"
+     ):
+         m.marginalize([x])
+
+
+ @pytest.mark.parametrize(
+     "transform, expected_warning",
+     (
+         (None, does_not_warn()),
+         (UNSET, does_not_warn()),
+         (transforms.log, does_not_warn()),
+         (transforms.Chain([transforms.log, transforms.logodds]), does_not_warn()),
+         (
+             transforms.Interval(0, 1),
+             pytest.warns(
+                 UserWarning, match="which depends on the marginalized idx may no longer work"
+             ),
+         ),
+         (
+             transforms.Chain([transforms.log, transforms.Interval(0, 1)]),
+             pytest.warns(
+                 UserWarning, match="which depends on the marginalized idx may no longer work"
+             ),
+         ),
+     ),
+ )
+ def test_marginalized_transforms(transform, expected_warning):
+     w = [0.1, 0.3, 0.6]
+     data = [0, 5, 10]
+     initval = 0.5  # Value that will be negative on the unconstrained space
+
+     with pm.Model() as m_ref:
+         sigma = pm.Mixture(
+             "sigma",
+             w=w,
+             comp_dists=pm.HalfNormal.dist([1, 2, 3]),
+             initval=initval,
+             default_transform=transform,
+         )
+         y = pm.Normal("y", 0, sigma, observed=data)
+
+     with MarginalModel() as m:
+         idx = pm.Categorical("idx", p=w)
+         sigma = pm.HalfNormal(
+             "sigma",
+             pt.switch(
+                 pt.eq(idx, 0),
+                 1,
+                 pt.switch(
+                     pt.eq(idx, 1),
+                     2,
+                     3,
+                 ),
+             ),
+             initval=initval,
+             default_transform=transform,
+         )
+         y = pm.Normal("y", 0, sigma, observed=data)
+
+     with expected_warning:
+         m.marginalize([idx])
+
+     ip = m.initial_point()
+     if transform is not None:
+         if transform is UNSET:
+             transform_name = "log"
+         else:
+             transform_name = transform.name
+         assert f"sigma_{transform_name}__" in ip
+     np.testing.assert_allclose(m.compile_logp()(ip), m_ref.compile_logp()(ip))
+
+
+ def test_data_container():
+     """Test that MarginalModel can handle Data containers."""
+     with MarginalModel(coords={"obs": [0]}) as marginal_m:
+         x = pm.Data("x", 2.5)
+         idx = pm.Bernoulli("idx", p=0.7, dims="obs")
+         y = pm.Normal("y", idx * x, dims="obs")
+
+     marginal_m.marginalize([idx])
+
+     logp_fn = marginal_m.compile_logp()
+
+     with pm.Model(coords={"obs": [0]}) as m_ref:
+         x = pm.Data("x", 2.5)
+         y = pm.NormalMixture("y", w=[0.3, 0.7], mu=[0, x], dims="obs")
+
+     ref_logp_fn = m_ref.compile_logp()
+
+     for i, x_val in enumerate((-1.5, 2.5, 3.5), start=1):
+         for m in (marginal_m, m_ref):
+             m.set_dim("obs", new_length=i, coord_values=tuple(range(i)))
+             pm.set_data({"x": x_val}, model=m)
+
+         ip = marginal_m.initial_point()
+         np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ip))
+
+
+ def test_mutable_indexing_jax_backend():
+     pytest.importorskip("jax")
+     from pymc.sampling.jax import get_jaxified_logp
+
+     with MarginalModel() as model:
+         data = pm.Data("data", np.zeros(10))
+
+         cat_effect = pm.Normal("cat_effect", sigma=1, shape=5)
+         cat_effect_idx = pm.Data("cat_effect_idx", np.array([0, 1] * 5))
+
+         is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10)
+         pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
+     model.marginalize(["is_outlier"])
+     get_jaxified_logp(model)
+
+
+ def test_marginal_model_func():
+     def create_model(model_class):
+         with model_class(coords={"trial": range(10)}) as m:
+             idx = pm.Bernoulli("idx", p=0.5, dims="trial")
+             mu = pt.where(idx, 1, -1)
+             sigma = pm.HalfNormal("sigma")
+             y = pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1] * 10)
+         return m
+
+     marginal_m = marginalize(create_model(pm.Model), ["idx"])
+     assert isinstance(marginal_m, MarginalModel)
+
+     reference_m = create_model(MarginalModel)
+     reference_m.marginalize(["idx"])
+
+     # Check forward graph representation is the same
+     marginal_fgraph, _ = fgraph_from_model(marginal_m)
+     reference_fgraph, _ = fgraph_from_model(reference_m)
+     assert equal_computations_up_to_root(marginal_fgraph.outputs, reference_fgraph.outputs)
+
+     # Check logp graph is the same
+     # This fails because OpFromGraphs comparison is broken
+     # assert equal_computations_up_to_root([marginal_m.logp()], [reference_m.logp()])
+     ip = marginal_m.initial_point()
+     np.testing.assert_allclose(
+         marginal_m.compile_logp()(ip),
+         reference_m.compile_logp()(ip),
+     )
+
+
+ class TestFullModels:
+     @pytest.fixture
+     def disaster_model(self):
+         # fmt: off
+         disaster_data = pd.Series(
+             [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
+              3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
+              2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,
+              1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
+              0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
+              3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
+              0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1]
+         )
+         # fmt: on
+         years = np.arange(1851, 1962)
+
+         with MarginalModel() as disaster_model:
+             switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max())
+             early_rate = pm.Exponential("early_rate", 1.0, initval=3)
+             late_rate = pm.Exponential("late_rate", 1.0, initval=1)
+             rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
+             with pytest.warns(Warning):
+                 disasters = pm.Poisson("disasters", rate, observed=disaster_data)
+
+         return disaster_model, years
+
+     def test_change_point_model(self, disaster_model):
+         m, years = disaster_model
+
+         ip = m.initial_point()
+         ip.pop("switchpoint")
+         ref_logp_fn = m.compile_logp(
+             [m["switchpoint"], m["disasters_observed"], m["disasters_unobserved"]]
+         )
+         ref_logp = logsumexp([ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years])
+
+         with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+             m.marginalize(m["switchpoint"])
+
+         logp = m.compile_logp([m["disasters_observed"], m["disasters_unobserved"]])(ip)
+         np.testing.assert_almost_equal(logp, ref_logp)
+
+     @pytest.mark.slow
+     def test_change_point_model_sampling(self, disaster_model):
+         m, _ = disaster_model
+
+         rng = np.random.default_rng(211)
+
+         with m:
+             before_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(
+                 sample=("draw", "chain")
+             )
+
+         with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+             m.marginalize([m["switchpoint"]])
+
+         with m:
+             after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(
+                 sample=("draw", "chain")
+             )
+
+         np.testing.assert_allclose(
+             before_marg["early_rate"].mean(), after_marg["early_rate"].mean(), rtol=1e-2
+         )
+         np.testing.assert_allclose(
+             before_marg["late_rate"].mean(), after_marg["late_rate"].mean(), rtol=1e-2
+         )
+         np.testing.assert_allclose(
+             before_marg["disasters_unobserved"].mean(),
+             after_marg["disasters_unobserved"].mean(),
+             rtol=1e-2,
+         )
+
+     @pytest.mark.parametrize("univariate", (True, False))
+     def test_vector_univariate_mixture(self, univariate):
+         with MarginalModel() as m:
+             idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
+
+             def dist(idx, size):
+                 return pm.math.switch(
+                     pm.math.eq(idx, 0),
+                     pm.Normal.dist([-10, -10], 1),
+                     pm.Normal.dist([10, 10], 1),
+                 )
+
+             pm.CustomDist("norm", idx, dist=dist)
+
+         m.marginalize(idx)
+         logp_fn = m.compile_logp()
+
+         if univariate:
+             with pm.Model() as ref_m:
+                 pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
+         else:
+             with pm.Model() as ref_m:
+                 pm.Mixture(
+                     "norm",
+                     w=[0.5, 0.5],
+                     comp_dists=[
+                         pm.MvNormal.dist([-10, -10], np.eye(2)),
+                         pm.MvNormal.dist([10, 10], np.eye(2)),
+                     ],
+                     shape=(2,),
+                 )
+         ref_logp_fn = ref_m.compile_logp()
+
+         for test_value in (
+             [-10, -10],
+             [10, 10],
+             [-10, 10],
+             [10, -10],
+         ):
+             point = {"norm": test_value}
+             np.testing.assert_allclose(logp_fn(point), ref_logp_fn(point))
+
+     def test_k_censored_clusters_model(self):
+         def build_model(build_batched: bool) -> MarginalModel:
+             data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
+             nobs = data.shape[0]
+             n_clusters = 5
+             coords = {
+                 "cluster": range(n_clusters),
+                 "ndim": ("x", "y"),
+                 "obs": range(nobs),
+             }
+             with MarginalModel(coords=coords) as m:
+                 if build_batched:
+                     idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"])
+                 else:
+                     idx = pm.math.stack(
+                         [
+                             pm.Categorical(f"idx_{i}", p=np.ones(n_clusters) / n_clusters)
+                             for i in range(nobs)
+                         ]
+                     )
+
+                 mu_x = pm.Normal(
+                     "mu_x",
+                     dims=["cluster"],
+                     transform=ordered,
+                     initval=np.linspace(-1, 1, n_clusters),
+                 )
+                 mu_y = pm.Normal("mu_y", dims=["cluster"])
+                 mu = pm.math.stack([mu_x, mu_y], axis=-1)  # (cluster, ndim)
+                 mu_indexed = mu[idx, :]
+
+                 sigma = pm.HalfNormal("sigma")
+
+                 y = pm.Censored(
+                     "y",
+                     dist=pm.Normal.dist(mu_indexed, sigma),
+                     lower=-3,
+                     upper=3,
+                     observed=data,
+                     dims=["obs", "ndim"],
+                 )
+
+             return m
+
+         m = build_model(build_batched=True)
+         ref_m = build_model(build_batched=False)
+
+         m.marginalize([m["idx"]])
+         ref_m.marginalize([n for n in ref_m.named_vars if n.startswith("idx_")])
+
+         test_point = m.initial_point()
+         np.testing.assert_almost_equal(
+             m.compile_logp()(test_point),
+             ref_m.compile_logp()(test_point),
+         )
+
+
+ class TestRecoverMarginals:
+     def test_basic(self):
+         with MarginalModel() as m:
+             sigma = pm.HalfNormal("sigma")
+             p = np.array([0.5, 0.2, 0.3])
+             k = pm.Categorical("k", p=p)
+             mu = np.array([-3.0, 0.0, 3.0])
+             mu_ = pt.as_tensor_variable(mu)
+             y = pm.Normal("y", mu=mu_[k], sigma=sigma)
+
+         m.marginalize([k])
+
+         rng = np.random.default_rng(211)
+
+         with m:
+             prior = pm.sample_prior_predictive(
+                 draws=20,
+                 random_seed=rng,
+                 return_inferencedata=False,
+             )
+             idata = InferenceData(posterior=dict_to_dataset(prior))
+
+         idata = m.recover_marginals(idata, return_samples=True)
+         post = idata.posterior
+         assert "k" in post
+         assert "lp_k" in post
+         assert post.k.shape == post.y.shape
+         assert post.lp_k.shape == (*post.k.shape, len(p))
+
+         def true_logp(y, sigma):
+             y = y.repeat(len(p)).reshape(len(y), -1)
+             sigma = sigma.repeat(len(p)).reshape(len(sigma), -1)
+             return log_softmax(
+                 np.log(p)
+                 + norm.logpdf(y, loc=mu, scale=sigma)
+                 + halfnorm.logpdf(sigma)
+                 + np.log(sigma),
+                 axis=1,
+             )
+
+         np.testing.assert_almost_equal(
+             true_logp(post.y.values.flatten(), post.sigma.values.flatten()),
+             post.lp_k[0].values,
+         )
+         np.testing.assert_almost_equal(logsumexp(post.lp_k, axis=-1), 0)
+
+     def test_coords(self):
+         """Test that recovered marginalized values keep the coords they originally had."""
+         with MarginalModel(coords={"year": [1990, 1991, 1992]}) as m:
+             sigma = pm.HalfNormal("sigma")
+             idx = pm.Bernoulli("idx", p=0.75, dims="year")
+             x = pm.Normal("x", mu=idx, sigma=sigma, dims="year")
+
+         m.marginalize([idx])
+         rng = np.random.default_rng(211)
+
+         with m:
+             prior = pm.sample_prior_predictive(
+                 draws=20,
+                 random_seed=rng,
+                 return_inferencedata=False,
+             )
+             idata = InferenceData(
+                 posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
+             )
+
+         idata = m.recover_marginals(idata, return_samples=True)
+         post = idata.posterior
+         assert post.idx.dims == ("chain", "draw", "year")
+         assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim")
+
+     def test_batched(self):
+         """Test that marginalization works for batched random variables"""
+         with MarginalModel() as m:
+             sigma = pm.HalfNormal("sigma")
+             idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
+             y = pm.Normal("y", mu=idx.T, sigma=sigma, shape=(2, 3))
+
+         m.marginalize([idx])
+
+         rng = np.random.default_rng(211)
+
+         with m:
+             prior = pm.sample_prior_predictive(
+                 draws=20,
+                 random_seed=rng,
+                 return_inferencedata=False,
+             )
+             idata = InferenceData(
+                 posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
+             )
+
+         idata = m.recover_marginals(idata, return_samples=True)
+         post = idata.posterior
+         assert post["y"].shape == (1, 20, 2, 3)
+         assert post["idx"].shape == (1, 20, 3, 2)
+         assert post["lp_idx"].shape == (1, 20, 3, 2, 2)
+
+     def test_nested(self):
+         """Test that marginalization works when there are nested marginalized RVs"""
+
+         with MarginalModel() as m:
+             idx = pm.Bernoulli("idx", p=0.75)
+             sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95))
+             sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0)
+
+         m.marginalize([idx, sub_idx])
+
+         rng = np.random.default_rng(211)
+
+         with m:
+             prior = pm.sample_prior_predictive(
+                 draws=20,
+                 random_seed=rng,
+                 return_inferencedata=False,
+             )
+             idata = InferenceData(posterior=dict_to_dataset(prior))
+
+         idata = m.recover_marginals(idata, return_samples=True)
+         post = idata.posterior
+         assert "idx" in post
+         assert "lp_idx" in post
+         assert post.idx.shape == post.y.shape
+         assert post.lp_idx.shape == (*post.idx.shape, 2)
+         assert "sub_idx" in post
+         assert "lp_sub_idx" in post
+         assert post.sub_idx.shape == post.y.shape
+         assert post.lp_sub_idx.shape == (*post.sub_idx.shape, 2)
+
+         def true_idx_logp(y):
+             idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1))
+             idx_1 = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
+             return log_softmax(np.stack([idx_0, idx_1]).T, axis=1)
+
+         np.testing.assert_almost_equal(
+             true_idx_logp(post.y.values.flatten()),
+             post.lp_idx[0].values,
+         )
+
+         def true_sub_idx_logp(y):
+             sub_idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1))
+             sub_idx_1 = np.log(0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
+             return log_softmax(np.stack([sub_idx_0, sub_idx_1]).T, axis=1)
+
+         np.testing.assert_almost_equal(
+             true_sub_idx_logp(post.y.values.flatten()),
+             post.lp_sub_idx[0].values,
+         )
+         np.testing.assert_almost_equal(logsumexp(post.lp_idx, axis=-1), 0)
+         np.testing.assert_almost_equal(logsumexp(post.lp_sub_idx, axis=-1), 0)