pymc-extras 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,25 +9,28 @@ import pytensor.tensor as pt
9
9
  import pytest
10
10
 
11
11
  from arviz import InferenceData, dict_to_dataset
12
+ from pymc import Model, draw
12
13
  from pymc.distributions import transforms
13
14
  from pymc.distributions.transforms import ordered
14
- from pymc.model.fgraph import fgraph_from_model
15
- from pymc.pytensorf import inputvars
15
+ from pymc.initial_point import make_initial_point_expression
16
+ from pymc.pytensorf import constant_fold, inputvars
16
17
  from pymc.util import UNSET
17
18
  from scipy.special import log_softmax, logsumexp
18
19
  from scipy.stats import halfnorm, norm
19
20
 
21
+ from pymc_extras.model.marginal.distributions import MarginalRV
20
22
  from pymc_extras.model.marginal.marginal_model import (
21
- MarginalModel,
22
23
  marginalize,
24
+ recover_marginals,
25
+ unmarginalize,
23
26
  )
24
- from tests.utils import equal_computations_up_to_root
27
+ from pymc_extras.utils.model_equivalence import equivalent_models
25
28
 
26
29
 
27
30
  def test_basic_marginalized_rv():
28
31
  data = [2] * 5
29
32
 
30
- with MarginalModel() as m:
33
+ with Model() as m:
31
34
  sigma = pm.HalfNormal("sigma")
32
35
  idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6])
33
36
  mu = pt.switch(
@@ -42,79 +45,105 @@ def test_basic_marginalized_rv():
42
45
  y = pm.Normal("y", mu=mu, sigma=sigma)
43
46
  z = pm.Normal("z", y, observed=data)
44
47
 
45
- m.marginalize([idx])
46
- assert idx not in m.free_RVs
47
- assert [rv.name for rv in m.marginalized_rvs] == ["idx"]
48
+ marginal_m = marginalize(m, [idx])
49
+ assert isinstance(marginal_m["y"].owner.op, MarginalRV)
50
+ assert ["idx"] not in [rv.name for rv in marginal_m.free_RVs]
51
+
52
+ # Test forward draws
53
+ y_draws, z_draws = draw(
54
+ [marginal_m["y"], marginal_m["z"]],
55
+ # Make sigma very small to make draws deterministic
56
+ givens={marginal_m["sigma"]: 0.001},
57
+ draws=1000,
58
+ random_seed=54,
59
+ )
60
+ assert sorted(np.unique(y_draws.round()) == [-1.0, 0.0, 1.0])
61
+ assert z_draws[y_draws < 0].mean() < z_draws[y_draws > 0].mean()
62
+
63
+ # Test initial_point
64
+ ips = make_initial_point_expression(
65
+ # Use basic_RVs to include the observed RV
66
+ free_rvs=marginal_m.basic_RVs,
67
+ rvs_to_transforms=marginal_m.rvs_to_transforms,
68
+ initval_strategies={},
69
+ )
70
+ # After simplification, we should have only constants in the graph (expect alloc which isn't constant folded):
71
+ ip_sigma, ip_y, ip_z = constant_fold(ips)
72
+ np.testing.assert_allclose(ip_sigma, 1.0)
73
+ np.testing.assert_allclose(ip_y, 1.0)
74
+ np.testing.assert_allclose(ip_z, np.full((5,), 1.0))
75
+
76
+ marginal_ip = marginal_m.initial_point()
77
+ expected_ip = m.initial_point()
78
+ expected_ip.pop("idx")
79
+ assert marginal_ip == expected_ip
48
80
 
49
81
  # Test logp
50
- with pm.Model() as m_ref:
82
+ with pm.Model() as ref_m:
51
83
  sigma = pm.HalfNormal("sigma")
52
84
  y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma)
53
85
  z = pm.Normal("z", y, observed=data)
54
86
 
55
- test_point = m_ref.initial_point()
56
- ref_logp = m_ref.compile_logp()(test_point)
57
- ref_dlogp = m_ref.compile_dlogp([m_ref["y"]])(test_point)
58
-
59
- # Assert we can marginalize and unmarginalize internally non-destructively
60
- for i in range(3):
61
- np.testing.assert_almost_equal(
62
- m.compile_logp()(test_point),
63
- ref_logp,
64
- )
65
- np.testing.assert_almost_equal(
66
- m.compile_dlogp([m["y"]])(test_point),
67
- ref_dlogp,
68
- )
87
+ np.testing.assert_almost_equal(
88
+ marginal_m.compile_logp()(marginal_ip),
89
+ ref_m.compile_logp()(marginal_ip),
90
+ )
91
+ np.testing.assert_almost_equal(
92
+ marginal_m.compile_dlogp([marginal_m["y"]])(marginal_ip),
93
+ ref_m.compile_dlogp([ref_m["y"]])(marginal_ip),
94
+ )
69
95
 
70
96
 
71
97
  def test_one_to_one_marginalized_rvs():
72
98
  """Test case with multiple, independent marginalized RVs."""
73
- with MarginalModel() as m:
99
+ with Model() as m:
74
100
  sigma = pm.HalfNormal("sigma")
75
101
  idx1 = pm.Bernoulli("idx1", p=0.75)
76
102
  x = pm.Normal("x", mu=idx1, sigma=sigma)
77
103
  idx2 = pm.Bernoulli("idx2", p=0.75, shape=(5,))
78
104
  y = pm.Normal("y", mu=(idx2 * 2 - 1), sigma=sigma, shape=(5,))
79
105
 
80
- m.marginalize([idx1, idx2])
81
- m["x"].owner is not m["y"].owner
82
- _m = m.clone()._marginalize()
83
- _m["x"].owner is not _m["y"].owner
106
+ marginal_m = marginalize(m, [idx1, idx2])
107
+ assert isinstance(marginal_m["x"].owner.op, MarginalRV)
108
+ assert isinstance(marginal_m["y"].owner.op, MarginalRV)
109
+ assert marginal_m["x"].owner is not marginal_m["y"].owner
84
110
 
85
- with pm.Model() as m_ref:
111
+ with pm.Model() as ref_m:
86
112
  sigma = pm.HalfNormal("sigma")
87
113
  x = pm.NormalMixture("x", w=[0.25, 0.75], mu=[0, 1], sigma=sigma)
88
114
  y = pm.NormalMixture("y", w=[0.25, 0.75], mu=[-1, 1], sigma=sigma, shape=(5,))
89
115
 
90
116
  # Test logp
91
- test_point = m_ref.initial_point()
92
- x_logp, y_logp = m.compile_logp(vars=[m["x"], m["y"]], sum=False)(test_point)
93
- x_ref_log, y_ref_logp = m_ref.compile_logp(vars=[m_ref["x"], m_ref["y"]], sum=False)(test_point)
117
+ test_point = ref_m.initial_point()
118
+ x_logp, y_logp = marginal_m.compile_logp(vars=[marginal_m["x"], marginal_m["y"]], sum=False)(
119
+ test_point
120
+ )
121
+ x_ref_log, y_ref_logp = ref_m.compile_logp(vars=[ref_m["x"], ref_m["y"]], sum=False)(test_point)
94
122
  np.testing.assert_array_almost_equal(x_logp, x_ref_log.sum())
95
123
  np.testing.assert_array_almost_equal(y_logp, y_ref_logp)
96
124
 
97
125
 
98
126
  def test_one_to_many_marginalized_rvs():
99
127
  """Test that marginalization works when there is more than one dependent RV"""
100
- with MarginalModel() as m:
128
+ with Model() as m:
101
129
  sigma = pm.HalfNormal("sigma")
102
130
  idx = pm.Bernoulli("idx", p=0.75)
103
131
  x = pm.Normal("x", mu=idx, sigma=sigma)
104
132
  y = pm.Normal("y", mu=(idx * 2 - 1), sigma=sigma, shape=(5,))
105
133
 
106
- ref_logp_x_y_fn = m.compile_logp([idx, x, y])
107
-
108
- with pytest.warns(UserWarning, match="There are multiple dependent variables"):
109
- m.marginalize([idx])
134
+ marginal_m = marginalize(m, [idx])
110
135
 
111
- m["x"].owner is not m["y"].owner
112
- _m = m.clone()._marginalize()
113
- _m["x"].owner is _m["y"].owner
136
+ marginal_x = marginal_m["x"]
137
+ marginal_y = marginal_m["y"]
138
+ assert isinstance(marginal_x.owner.op, MarginalRV)
139
+ assert isinstance(marginal_y.owner.op, MarginalRV)
140
+ assert marginal_x.owner is marginal_y.owner
114
141
 
115
- tp = m.initial_point()
142
+ ref_logp_x_y_fn = m.compile_logp([idx, x, y])
143
+ tp = marginal_m.initial_point()
116
144
  ref_logp_x_y = logsumexp([ref_logp_x_y_fn({**tp, **{"idx": idx}}) for idx in (0, 1)])
117
- logp_x_y = m.compile_logp([x, y])(tp)
145
+ with pytest.warns(UserWarning, match="There are multiple dependent variables"):
146
+ logp_x_y = marginal_m.compile_logp([marginal_x, marginal_y])(tp)
118
147
  np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y)
119
148
 
120
149
 
@@ -122,7 +151,7 @@ def test_one_to_many_unaligned_marginalized_rvs():
122
151
  """Test that marginalization works when there is more than one dependent RV with batch dimensions that are not aligned"""
123
152
 
124
153
  def build_model(build_batched: bool):
125
- with MarginalModel() as m:
154
+ with Model() as m:
126
155
  if build_batched:
127
156
  idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2))
128
157
  else:
@@ -134,44 +163,41 @@ def test_one_to_many_unaligned_marginalized_rvs():
134
163
 
135
164
  return m
136
165
 
137
- m = build_model(build_batched=True)
138
- ref_m = build_model(build_batched=False)
166
+ marginal_m = marginalize(build_model(build_batched=True), ["idx"])
167
+ ref_m = marginalize(build_model(build_batched=False), [f"idx_{i}" for i in range(6)])
139
168
 
140
- with pytest.warns(UserWarning, match="There are multiple dependent variables"):
141
- m.marginalize(["idx"])
142
- ref_m.marginalize([f"idx_{i}" for i in range(6)])
169
+ test_point = marginal_m.initial_point()
143
170
 
144
- test_point = m.initial_point()
145
- np.testing.assert_allclose(
146
- m.compile_logp()(test_point),
147
- ref_m.compile_logp()(test_point),
148
- )
171
+ with pytest.warns(UserWarning, match="There are multiple dependent variables"):
172
+ np.testing.assert_allclose(
173
+ marginal_m.compile_logp()(test_point),
174
+ ref_m.compile_logp()(test_point),
175
+ )
149
176
 
150
177
 
151
178
  def test_many_to_one_marginalized_rvs():
152
179
  """Test when random variables depend on multiple marginalized variables"""
153
- with MarginalModel() as m:
180
+ with Model() as m:
154
181
  x = pm.Bernoulli("x", 0.1)
155
182
  y = pm.Bernoulli("y", 0.3)
156
183
  z = pm.DiracDelta("z", c=x + y)
157
184
 
158
- m.marginalize([x, y])
159
- logp = m.compile_logp()
185
+ logp_fn = marginalize(m, [x, y]).compile_logp()
160
186
 
161
- np.testing.assert_allclose(np.exp(logp({"z": 0})), 0.9 * 0.7)
162
- np.testing.assert_allclose(np.exp(logp({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7)
163
- np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3)
187
+ np.testing.assert_allclose(np.exp(logp_fn({"z": 0})), 0.9 * 0.7)
188
+ np.testing.assert_allclose(np.exp(logp_fn({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7)
189
+ np.testing.assert_allclose(np.exp(logp_fn({"z": 2})), 0.1 * 0.3)
164
190
 
165
191
 
166
192
  @pytest.mark.parametrize("batched", (False, "left", "right"))
167
193
  def test_nested_marginalized_rvs(batched):
168
194
  """Test that marginalization works when there are nested marginalized RVs"""
169
195
 
170
- def build_model(build_batched: bool) -> MarginalModel:
196
+ def build_model(build_batched: bool) -> Model:
171
197
  idx_shape = (3,) if build_batched else ()
172
198
  sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5)
173
199
 
174
- with MarginalModel() as m:
200
+ with Model() as m:
175
201
  sigma = pm.HalfNormal("sigma")
176
202
 
177
203
  idx = pm.Bernoulli("idx", p=0.75, shape=idx_shape)
@@ -186,10 +212,33 @@ def test_nested_marginalized_rvs(batched):
186
212
 
187
213
  return m
188
214
 
189
- m = build_model(build_batched=batched)
190
- with pytest.warns(UserWarning, match="There are multiple dependent variables"):
191
- m.marginalize(["idx", "sub_idx"])
192
- assert sorted(m.name for m in m.marginalized_rvs) == ["idx", "sub_idx"]
215
+ marginal_m = marginalize(build_model(build_batched=batched), ["idx", "sub_idx"])
216
+ assert all(rv.name not in ("idx", "sub_idx") for rv in marginal_m.free_RVs)
217
+
218
+ # Test forward draws and initial_point, shouldn't depend on batching, so we only test one case
219
+ if not batched:
220
+ # Test forward draws
221
+ dep_draws, sub_dep_draws = draw(
222
+ [marginal_m["dep"], marginal_m["sub_dep"]],
223
+ # Make sigma very small to make draws deterministic
224
+ givens={marginal_m["sigma"]: 0.001},
225
+ draws=1000,
226
+ random_seed=214,
227
+ )
228
+ assert sorted(np.unique(dep_draws.round()) == [-1000.0, 1000.0])
229
+ assert sorted(np.unique(sub_dep_draws.round()) == [-1000.0, -900.0, 1000.0, 1100.0])
230
+
231
+ # Test initial_point
232
+ ips = make_initial_point_expression(
233
+ free_rvs=marginal_m.free_RVs,
234
+ rvs_to_transforms=marginal_m.rvs_to_transforms,
235
+ initval_strategies={},
236
+ )
237
+ # After simplification, we should have only constants in the graph
238
+ ip_sigma, ip_dep, ip_sub_dep = constant_fold(ips)
239
+ np.testing.assert_allclose(ip_sigma, 1.0)
240
+ np.testing.assert_allclose(ip_dep, 1000.0)
241
+ np.testing.assert_allclose(ip_sub_dep, np.full((5,), 1100.0))
193
242
 
194
243
  # Test logp
195
244
  ref_m = build_model(build_batched=False)
@@ -210,14 +259,70 @@ def test_nested_marginalized_rvs(batched):
210
259
  if batched:
211
260
  ref_logp *= 3
212
261
 
213
- test_point = m.initial_point()
262
+ test_point = marginal_m.initial_point()
214
263
  test_point["dep"] = np.full_like(test_point["dep"], 1000)
215
264
  test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
216
- logp = m.compile_logp(vars=[m["dep"], m["sub_dep"]])(test_point)
265
+
266
+ with pytest.warns(UserWarning, match="There are multiple dependent variables"):
267
+ logp = marginal_m.compile_logp(vars=[marginal_m["dep"], marginal_m["sub_dep"]])(test_point)
217
268
 
218
269
  np.testing.assert_almost_equal(logp, ref_logp)
219
270
 
220
271
 
272
+ def test_interdependent_rvs():
273
+ """Test Marginalization when dependent RVs are interdependent."""
274
+ with Model() as m:
275
+ idx = pm.Bernoulli("idx", p=0.75)
276
+ x = pm.Normal("x", mu=idx * 2, sigma=1e-3)
277
+ # Y depends on both x and idx
278
+ y = pm.Normal("y", mu=x * idx * 2, sigma=1e-3)
279
+
280
+ marginal_m = marginalize(m, "idx")
281
+
282
+ marginal_x = marginal_m["x"]
283
+ marginal_y = marginal_m["y"]
284
+ assert isinstance(marginal_x.owner.op, MarginalRV)
285
+ assert isinstance(marginal_y.owner.op, MarginalRV)
286
+ assert marginal_x.owner is marginal_y.owner
287
+
288
+ # Test forward draws
289
+ x_draws, y_draws = draw([marginal_x, marginal_y], draws=1000, random_seed=54)
290
+ assert sorted(np.unique(x_draws.round())) == [0, 2]
291
+ assert sorted(np.unique(y_draws.round())) == [0, 4]
292
+ assert np.unique(y_draws[x_draws < 1].round()) == [0]
293
+ assert np.unique(y_draws[x_draws > 1].round()) == [4]
294
+
295
+ # Test initial_point
296
+ ips = make_initial_point_expression(
297
+ free_rvs=marginal_m.free_RVs,
298
+ rvs_to_transforms={},
299
+ initval_strategies={},
300
+ )
301
+ # After simplification, we should have only constants in the graph
302
+ ip_x, ip_y = constant_fold(ips)
303
+ np.testing.assert_allclose(ip_x, 2.0)
304
+ np.testing.assert_allclose(ip_y, 4.0)
305
+
306
+ # Test custom initval strategy
307
+ ips = make_initial_point_expression(
308
+ # Test that order does not matter
309
+ free_rvs=marginal_m.free_RVs[::-1],
310
+ rvs_to_transforms={},
311
+ initval_strategies={marginal_x: pt.constant(5.0)},
312
+ )
313
+ ip_y, ip_x = constant_fold(ips)
314
+ np.testing.assert_allclose(ip_x, 5.0)
315
+ np.testing.assert_allclose(ip_y, 10.0)
316
+
317
+ # Test logp
318
+ test_point = marginal_m.initial_point()
319
+ ref_logp_fn = m.compile_logp([m["idx"], m["x"], m["y"]])
320
+ ref_logp = logsumexp([ref_logp_fn({**test_point, **{"idx": idx}}) for idx in (0, 1)])
321
+ with pytest.warns(UserWarning, match="There are multiple dependent variables"):
322
+ logp = marginal_m.compile_logp([marginal_m["x"], marginal_m["y"]])(test_point)
323
+ np.testing.assert_almost_equal(logp, ref_logp)
324
+
325
+
221
326
  @pytest.mark.parametrize("advanced_indexing", (False, True))
222
327
  def test_marginalized_index_as_key(advanced_indexing):
223
328
  """Test we can marginalize graphs where indexing is used as a mapping."""
@@ -232,13 +337,13 @@ def test_marginalized_index_as_key(advanced_indexing):
232
337
  y_val = -1
233
338
  shape = ()
234
339
 
235
- with MarginalModel() as m:
340
+ with Model() as m:
236
341
  x = pm.Categorical("x", p=w, shape=shape)
237
342
  y = pm.Normal("y", mu[x].T, sigma=1, observed=y_val)
238
343
 
239
- m.marginalize(x)
344
+ marginal_m = marginalize(m, x)
240
345
 
241
- marginal_logp = m.compile_logp(sum=False)({})[0]
346
+ marginal_logp = marginal_m.compile_logp(sum=False)({})[0]
242
347
  ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu.T, sigma=1, shape=shape), y_val).eval()
243
348
 
244
349
  np.testing.assert_allclose(marginal_logp, ref_logp)
@@ -247,8 +352,8 @@ def test_marginalized_index_as_key(advanced_indexing):
247
352
  def test_marginalized_index_as_value_and_key():
248
353
  """Test we can marginalize graphs were marginalized_rv is indexed."""
249
354
 
250
- def build_model(build_batched: bool) -> MarginalModel:
251
- with MarginalModel() as m:
355
+ def build_model(build_batched: bool) -> Model:
356
+ with Model() as m:
252
357
  if build_batched:
253
358
  latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,))
254
359
  else:
@@ -270,16 +375,16 @@ def test_marginalized_index_as_value_and_key():
270
375
  m = build_model(build_batched=True)
271
376
  ref_m = build_model(build_batched=False)
272
377
 
273
- m.marginalize(["latent_state"])
274
- ref_m.marginalize([f"latent_state_{i}" for i in range(4)])
378
+ m = marginalize(m, ["latent_state"])
379
+ ref_m = marginalize(ref_m, [f"latent_state_{i}" for i in range(4)])
275
380
  test_point = {"picked_intensity": 1}
276
381
  np.testing.assert_allclose(
277
382
  m.compile_logp()(test_point),
278
383
  ref_m.compile_logp()(test_point),
279
384
  )
280
385
 
281
- m.marginalize(["picked_intensity"])
282
- ref_m.marginalize(["picked_intensity"])
386
+ m = marginalize(m, ["picked_intensity"])
387
+ ref_m = marginalize(ref_m, ["picked_intensity"])
283
388
  test_point = {}
284
389
  np.testing.assert_allclose(
285
390
  m.compile_logp()(test_point),
@@ -291,99 +396,99 @@ class TestNotSupportedMixedDims:
291
396
  """Test lack of support for models where batch dims of marginalized variables are mixed."""
292
397
 
293
398
  def test_mixed_dims_via_transposed_dot(self):
294
- with MarginalModel() as m:
399
+ with Model() as m:
295
400
  idx = pm.Bernoulli("idx", p=0.7, shape=2)
296
401
  y = pm.Normal("y", mu=idx @ idx.T)
297
- with pytest.raises(NotImplementedError):
298
- m.marginalize(idx)
402
+
403
+ with pytest.raises(NotImplementedError):
404
+ marginalize(m, idx)
299
405
 
300
406
  def test_mixed_dims_via_indexing(self):
301
407
  mean = pt.as_tensor([[0.1, 0.9], [0.6, 0.4]])
302
408
 
303
- with MarginalModel() as m:
409
+ with Model() as m:
304
410
  idx = pm.Bernoulli("idx", p=0.7, shape=2)
305
411
  y = pm.Normal("y", mu=mean[idx, :] + mean[:, idx])
306
- with pytest.raises(NotImplementedError):
307
- m.marginalize(idx)
412
+ with pytest.raises(NotImplementedError):
413
+ marginalize(m, idx)
308
414
 
309
- with MarginalModel() as m:
415
+ with Model() as m:
310
416
  idx = pm.Bernoulli("idx", p=0.7, shape=2)
311
417
  y = pm.Normal("y", mu=mean[idx, None] + mean[None, idx])
312
- with pytest.raises(NotImplementedError):
313
- m.marginalize(idx)
418
+ with pytest.raises(NotImplementedError):
419
+ marginalize(m, idx)
314
420
 
315
- with MarginalModel() as m:
421
+ with Model() as m:
316
422
  idx = pm.Bernoulli("idx", p=0.7, shape=2)
317
423
  mu = pt.specify_broadcastable(mean[:, None][idx], 1) + pt.specify_broadcastable(
318
424
  mean[None, :][:, idx], 0
319
425
  )
320
426
  y = pm.Normal("y", mu=mu)
321
- with pytest.raises(NotImplementedError):
322
- m.marginalize(idx)
427
+ with pytest.raises(NotImplementedError):
428
+ marginalize(m, idx)
323
429
 
324
- with MarginalModel() as m:
430
+ with Model() as m:
325
431
  idx = pm.Bernoulli("idx", p=0.7, shape=2)
326
432
  y = pm.Normal("y", mu=idx[0] + idx[1])
327
- with pytest.raises(NotImplementedError):
328
- m.marginalize(idx)
433
+ with pytest.raises(NotImplementedError):
434
+ marginalize(m, idx)
329
435
 
330
436
  def test_mixed_dims_via_vector_indexing(self):
331
- with MarginalModel() as m:
437
+ with Model() as m:
332
438
  idx = pm.Bernoulli("idx", p=0.7, shape=2)
333
439
  y = pm.Normal("y", mu=idx[[0, 1, 0, 0]])
334
- with pytest.raises(NotImplementedError):
335
- m.marginalize(idx)
440
+ with pytest.raises(NotImplementedError):
441
+ marginalize(m, idx)
336
442
 
337
- with MarginalModel() as m:
443
+ with Model() as m:
338
444
  idx = pm.Categorical("key", p=[0.1, 0.3, 0.6], shape=(2, 2))
339
445
  y = pm.Normal("y", pt.as_tensor([[0, 1], [2, 3]])[idx.astype(bool)])
340
- with pytest.raises(NotImplementedError):
341
- m.marginalize(idx)
446
+ with pytest.raises(NotImplementedError):
447
+ marginalize(m, idx)
342
448
 
343
449
  def test_mixed_dims_via_support_dimension(self):
344
- with MarginalModel() as m:
450
+ with Model() as m:
345
451
  x = pm.Bernoulli("x", p=0.7, shape=3)
346
452
  y = pm.Dirichlet("y", a=x * 10 + 1)
347
- with pytest.raises(NotImplementedError):
348
- m.marginalize(x)
453
+ with pytest.raises(NotImplementedError):
454
+ marginalize(m, x)
349
455
 
350
456
  def test_mixed_dims_via_nested_marginalization(self):
351
- with MarginalModel() as m:
457
+ with Model() as m:
352
458
  x = pm.Bernoulli("x", p=0.7, shape=(3,))
353
459
  y = pm.Bernoulli("y", p=0.7, shape=(2,))
354
460
  z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2))
355
461
 
356
- with pytest.raises(NotImplementedError):
357
- m.marginalize([x, y])
462
+ with pytest.raises(NotImplementedError):
463
+ marginalize(m, [x, y])
358
464
 
359
465
 
360
466
  def test_marginalized_deterministic_and_potential():
361
467
  rng = np.random.default_rng(299)
362
468
 
363
- with MarginalModel() as m:
469
+ with Model() as m:
364
470
  x = pm.Bernoulli("x", p=0.7)
365
471
  y = pm.Normal("y", x)
366
472
  z = pm.Normal("z", x)
367
473
  det = pm.Deterministic("det", y + z)
368
474
  pot = pm.Potential("pot", y + z + 1)
369
475
 
370
- with pytest.warns(UserWarning, match="There are multiple dependent variables"):
371
- m.marginalize([x])
476
+ marginal_m = marginalize(m, [x])
372
477
 
373
478
  y_draw, z_draw, det_draw, pot_draw = pm.draw([y, z, det, pot], draws=5, random_seed=rng)
374
479
  np.testing.assert_almost_equal(y_draw + z_draw, det_draw)
375
480
  np.testing.assert_almost_equal(det_draw, pot_draw - 1)
376
481
 
377
- y_value = m.rvs_to_values[y]
378
- z_value = m.rvs_to_values[z]
379
- det_value, pot_value = m.replace_rvs_by_values([det, pot])
482
+ y_value = marginal_m.rvs_to_values[marginal_m["y"]]
483
+ z_value = marginal_m.rvs_to_values[marginal_m["z"]]
484
+ det_value, pot_value = marginal_m.replace_rvs_by_values([marginal_m["det"], marginal_m["pot"]])
380
485
  assert set(inputvars([det_value, pot_value])) == {y_value, z_value}
381
486
  assert det_value.eval({y_value: 2, z_value: 5}) == 7
382
487
  assert pot_value.eval({y_value: 2, z_value: 5}) == 8
383
488
 
384
489
 
385
490
  def test_not_supported_marginalized_deterministic_and_potential():
386
- with MarginalModel() as m:
491
+ with Model() as m:
387
492
  x = pm.Bernoulli("x", p=0.7)
388
493
  y = pm.Normal("y", x)
389
494
  det = pm.Deterministic("det", x + y)
@@ -391,9 +496,9 @@ def test_not_supported_marginalized_deterministic_and_potential():
391
496
  with pytest.raises(
392
497
  NotImplementedError, match="Cannot marginalize x due to dependent Deterministic det"
393
498
  ):
394
- m.marginalize([x])
499
+ marginalize(m, [x])
395
500
 
396
- with MarginalModel() as m:
501
+ with Model() as m:
397
502
  x = pm.Bernoulli("x", p=0.7)
398
503
  y = pm.Normal("y", x)
399
504
  pot = pm.Potential("pot", x + y)
@@ -401,7 +506,7 @@ def test_not_supported_marginalized_deterministic_and_potential():
401
506
  with pytest.raises(
402
507
  NotImplementedError, match="Cannot marginalize x due to dependent Potential pot"
403
508
  ):
404
- m.marginalize([x])
509
+ marginalize(m, [x])
405
510
 
406
511
 
407
512
  @pytest.mark.parametrize(
@@ -410,15 +515,15 @@ def test_not_supported_marginalized_deterministic_and_potential():
410
515
  (None, does_not_warn()),
411
516
  (UNSET, does_not_warn()),
412
517
  (transforms.log, does_not_warn()),
413
- (transforms.Chain([transforms.log, transforms.logodds]), does_not_warn()),
518
+ (transforms.Chain([transforms.logodds, transforms.log]), does_not_warn()),
414
519
  (
415
- transforms.Interval(0, 1),
520
+ transforms.Interval(0, 2),
416
521
  pytest.warns(
417
522
  UserWarning, match="which depends on the marginalized idx may no longer work"
418
523
  ),
419
524
  ),
420
525
  (
421
- transforms.Chain([transforms.log, transforms.Interval(0, 1)]),
526
+ transforms.Chain([transforms.log, transforms.Interval(-1, 1)]),
422
527
  pytest.warns(
423
528
  UserWarning, match="which depends on the marginalized idx may no longer work"
424
529
  ),
@@ -428,7 +533,7 @@ def test_not_supported_marginalized_deterministic_and_potential():
428
533
  def test_marginalized_transforms(transform, expected_warning):
429
534
  w = [0.1, 0.3, 0.6]
430
535
  data = [0, 5, 10]
431
- initval = 0.5 # Value that will be negative on the unconstrained space
536
+ initval = 0.7 # Value that will be negative on the unconstrained space
432
537
 
433
538
  with pm.Model() as m_ref:
434
539
  sigma = pm.Mixture(
@@ -440,7 +545,7 @@ def test_marginalized_transforms(transform, expected_warning):
440
545
  )
441
546
  y = pm.Normal("y", 0, sigma, observed=data)
442
547
 
443
- with MarginalModel() as m:
548
+ with Model() as m:
444
549
  idx = pm.Categorical("idx", p=w)
445
550
  sigma = pm.HalfNormal(
446
551
  "sigma",
@@ -453,32 +558,32 @@ def test_marginalized_transforms(transform, expected_warning):
453
558
  3,
454
559
  ),
455
560
  ),
456
- initval=initval,
457
561
  default_transform=transform,
458
562
  )
459
563
  y = pm.Normal("y", 0, sigma, observed=data)
460
564
 
461
565
  with expected_warning:
462
- m.marginalize([idx])
566
+ marginal_m = marginalize(m, [idx])
463
567
 
464
- ip = m.initial_point()
568
+ marginal_m.set_initval(marginal_m["sigma"], initval)
569
+ ip = marginal_m.initial_point()
465
570
  if transform is not None:
466
571
  if transform is UNSET:
467
572
  transform_name = "log"
468
573
  else:
469
574
  transform_name = transform.name
470
- assert f"sigma_{transform_name}__" in ip
471
- np.testing.assert_allclose(m.compile_logp()(ip), m_ref.compile_logp()(ip))
575
+ assert -np.inf < ip[f"sigma_{transform_name}__"] < 0.0
576
+ np.testing.assert_allclose(marginal_m.compile_logp()(ip), m_ref.compile_logp()(ip))
472
577
 
473
578
 
474
579
  def test_data_container():
475
580
  """Test that MarginalModel can handle Data containers."""
476
- with MarginalModel(coords={"obs": [0]}) as marginal_m:
581
+ with Model(coords={"obs": [0]}) as m:
477
582
  x = pm.Data("x", 2.5)
478
583
  idx = pm.Bernoulli("idx", p=0.7, dims="obs")
479
584
  y = pm.Normal("y", idx * x, dims="obs")
480
585
 
481
- marginal_m.marginalize([idx])
586
+ marginal_m = marginalize(m, [idx])
482
587
 
483
588
  logp_fn = marginal_m.compile_logp()
484
589
 
@@ -501,7 +606,7 @@ def test_mutable_indexing_jax_backend():
501
606
  pytest.importorskip("jax")
502
607
  from pymc.sampling.jax import get_jaxified_logp
503
608
 
504
- with MarginalModel() as model:
609
+ with Model() as model:
505
610
  data = pm.Data("data", np.zeros(10))
506
611
 
507
612
  cat_effect = pm.Normal("cat_effect", sigma=1, shape=5)
@@ -509,38 +614,8 @@ def test_mutable_indexing_jax_backend():
509
614
 
510
615
  is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10)
511
616
  pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
512
- model.marginalize(["is_outlier"])
513
- get_jaxified_logp(model)
514
-
515
-
516
- def test_marginal_model_func():
517
- def create_model(model_class):
518
- with model_class(coords={"trial": range(10)}) as m:
519
- idx = pm.Bernoulli("idx", p=0.5, dims="trial")
520
- mu = pt.where(idx, 1, -1)
521
- sigma = pm.HalfNormal("sigma")
522
- y = pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1] * 10)
523
- return m
524
-
525
- marginal_m = marginalize(create_model(pm.Model), ["idx"])
526
- assert isinstance(marginal_m, MarginalModel)
527
-
528
- reference_m = create_model(MarginalModel)
529
- reference_m.marginalize(["idx"])
530
-
531
- # Check forward graph representation is the same
532
- marginal_fgraph, _ = fgraph_from_model(marginal_m)
533
- reference_fgraph, _ = fgraph_from_model(reference_m)
534
- assert equal_computations_up_to_root(marginal_fgraph.outputs, reference_fgraph.outputs)
535
-
536
- # Check logp graph is the same
537
- # This fails because OpFromGraphs comparison is broken
538
- # assert equal_computations_up_to_root([marginal_m.logp()], [reference_m.logp()])
539
- ip = marginal_m.initial_point()
540
- np.testing.assert_allclose(
541
- marginal_m.compile_logp()(ip),
542
- reference_m.compile_logp()(ip),
543
- )
617
+ marginal_model = marginalize(model, ["is_outlier"])
618
+ get_jaxified_logp(marginal_model)
544
619
 
545
620
 
546
621
  class TestFullModels:
@@ -559,10 +634,10 @@ class TestFullModels:
559
634
  # fmt: on
560
635
  years = np.arange(1851, 1962)
561
636
 
562
- with MarginalModel() as disaster_model:
637
+ with Model() as disaster_model:
563
638
  switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max())
564
- early_rate = pm.Exponential("early_rate", 1.0, initval=3)
565
- late_rate = pm.Exponential("late_rate", 1.0, initval=1)
639
+ early_rate = pm.Exponential("early_rate", 1.0)
640
+ late_rate = pm.Exponential("late_rate", 1.0)
566
641
  rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
567
642
  with pytest.warns(Warning):
568
643
  disasters = pm.Poisson("disasters", rate, observed=disaster_data)
@@ -573,17 +648,21 @@ class TestFullModels:
573
648
  m, years = disaster_model
574
649
 
575
650
  ip = m.initial_point()
651
+ ip["late_rate_log__"] += 1.0 # Make early and endpoint ip different
652
+
576
653
  ip.pop("switchpoint")
577
654
  ref_logp_fn = m.compile_logp(
578
655
  [m["switchpoint"], m["disasters_observed"], m["disasters_unobserved"]]
579
656
  )
580
657
  ref_logp = logsumexp([ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years])
581
658
 
582
- with pytest.warns(UserWarning, match="There are multiple dependent variables"):
583
- m.marginalize(m["switchpoint"])
659
+ marginal_m = marginalize(m, m["switchpoint"])
584
660
 
585
- logp = m.compile_logp([m["disasters_observed"], m["disasters_unobserved"]])(ip)
586
- np.testing.assert_almost_equal(logp, ref_logp)
661
+ with pytest.warns(UserWarning, match="There are multiple dependent variables"):
662
+ marginal_m_logp = marginal_m.compile_logp(
663
+ [marginal_m["disasters_observed"], marginal_m["disasters_unobserved"]]
664
+ )(ip)
665
+ np.testing.assert_almost_equal(marginal_m_logp, ref_logp)
587
666
 
588
667
  @pytest.mark.slow
589
668
  def test_change_point_model_sampling(self, disaster_model):
@@ -596,13 +675,13 @@ class TestFullModels:
596
675
  sample=("draw", "chain")
597
676
  )
598
677
 
599
- with pytest.warns(UserWarning, match="There are multiple dependent variables"):
600
- m.marginalize([m["switchpoint"]])
678
+ marginal_m = marginalize(m, "switchpoint")
601
679
 
602
- with m:
603
- after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(
604
- sample=("draw", "chain")
605
- )
680
+ with marginal_m:
681
+ with pytest.warns(UserWarning, match="There are multiple dependent variables"):
682
+ after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(
683
+ sample=("draw", "chain")
684
+ )
606
685
 
607
686
  np.testing.assert_allclose(
608
687
  before_marg["early_rate"].mean(), after_marg["early_rate"].mean(), rtol=1e-2
@@ -618,7 +697,7 @@ class TestFullModels:
618
697
 
619
698
  @pytest.mark.parametrize("univariate", (True, False))
620
699
  def test_vector_univariate_mixture(self, univariate):
621
- with MarginalModel() as m:
700
+ with Model() as m:
622
701
  idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
623
702
 
624
703
  def dist(idx, size):
@@ -630,8 +709,8 @@ class TestFullModels:
630
709
 
631
710
  pm.CustomDist("norm", idx, dist=dist)
632
711
 
633
- m.marginalize(idx)
634
- logp_fn = m.compile_logp()
712
+ marginal_m = marginalize(m, idx)
713
+ logp_fn = marginal_m.compile_logp()
635
714
 
636
715
  if univariate:
637
716
  with pm.Model() as ref_m:
@@ -659,16 +738,17 @@ class TestFullModels:
659
738
  np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
660
739
 
661
740
  def test_k_censored_clusters_model(self):
662
- def build_model(build_batched: bool) -> MarginalModel:
663
- data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
664
- nobs = data.shape[0]
665
- n_clusters = 5
741
+ data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
742
+ nobs = data.shape[0]
743
+ n_clusters = 5
744
+
745
+ def build_model(build_batched: bool) -> Model:
666
746
  coords = {
667
747
  "cluster": range(n_clusters),
668
748
  "ndim": ("x", "y"),
669
749
  "obs": range(nobs),
670
750
  }
671
- with MarginalModel(coords=coords) as m:
751
+ with Model(coords=coords) as m:
672
752
  if build_batched:
673
753
  idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"])
674
754
  else:
@@ -683,7 +763,6 @@ class TestFullModels:
683
763
  "mu_x",
684
764
  dims=["cluster"],
685
765
  transform=ordered,
686
- initval=np.linspace(-1, 1, n_clusters),
687
766
  )
688
767
  mu_y = pm.Normal("mu_y", dims=["cluster"])
689
768
  mu = pm.math.stack([mu_x, mu_y], axis=-1) # (cluster, ndim)
@@ -702,12 +781,10 @@ class TestFullModels:
702
781
 
703
782
  return m
704
783
 
705
- m = build_model(build_batched=True)
706
- ref_m = build_model(build_batched=False)
707
-
708
- m.marginalize([m["idx"]])
709
- ref_m.marginalize([n for n in ref_m.named_vars if n.startswith("idx_")])
784
+ m = marginalize(build_model(build_batched=True), "idx")
785
+ m.set_initval(m["mu_x"], np.linspace(-1, 1, n_clusters))
710
786
 
787
+ ref_m = marginalize(build_model(build_batched=False), [f"idx_{i}" for i in range(nobs)])
711
788
  test_point = m.initial_point()
712
789
  np.testing.assert_almost_equal(
713
790
  m.compile_logp()(test_point),
@@ -715,9 +792,32 @@ class TestFullModels:
715
792
  )
716
793
 
717
794
 
795
+ def test_unmarginalize():
796
+ with pm.Model() as m:
797
+ idx = pm.Bernoulli("idx", p=0.5)
798
+ sub_idx = pm.Bernoulli("sub_idx", p=pt.as_tensor([0.3, 0.7])[idx])
799
+ x = pm.Normal("x", mu=(idx + sub_idx) - 1)
800
+
801
+ marginal_m = marginalize(m, [idx, sub_idx])
802
+ assert not equivalent_models(marginal_m, m)
803
+
804
+ unmarginal_m = unmarginalize(marginal_m)
805
+ assert equivalent_models(unmarginal_m, m)
806
+
807
+ unmarginal_idx_explicit = unmarginalize(marginal_m, ("idx", "sub_idx"))
808
+ assert equivalent_models(unmarginal_idx_explicit, m)
809
+
810
+ # Test partial unmarginalize
811
+ unmarginal_idx = unmarginalize(marginal_m, "idx")
812
+ assert equivalent_models(unmarginal_idx, marginalize(m, "sub_idx"))
813
+
814
+ unmarginal_sub_idx = unmarginalize(marginal_m, "sub_idx")
815
+ assert equivalent_models(unmarginal_sub_idx, marginalize(m, "idx"))
816
+
817
+
718
818
  class TestRecoverMarginals:
719
819
  def test_basic(self):
720
- with MarginalModel() as m:
820
+ with Model() as m:
721
821
  sigma = pm.HalfNormal("sigma")
722
822
  p = np.array([0.5, 0.2, 0.3])
723
823
  k = pm.Categorical("k", p=p)
@@ -725,11 +825,11 @@ class TestRecoverMarginals:
725
825
  mu_ = pt.as_tensor_variable(mu)
726
826
  y = pm.Normal("y", mu=mu_[k], sigma=sigma)
727
827
 
728
- m.marginalize([k])
828
+ marginal_m = marginalize(m, [k])
729
829
 
730
830
  rng = np.random.default_rng(211)
731
831
 
732
- with m:
832
+ with marginal_m:
733
833
  prior = pm.sample_prior_predictive(
734
834
  draws=20,
735
835
  random_seed=rng,
@@ -737,7 +837,7 @@ class TestRecoverMarginals:
737
837
  )
738
838
  idata = InferenceData(posterior=dict_to_dataset(prior))
739
839
 
740
- idata = m.recover_marginals(idata, return_samples=True)
840
+ idata = recover_marginals(marginal_m, idata, return_samples=True)
741
841
  post = idata.posterior
742
842
  assert "k" in post
743
843
  assert "lp_k" in post
@@ -763,15 +863,15 @@ class TestRecoverMarginals:
763
863
 
764
864
  def test_coords(self):
765
865
  """Test if coords can be recovered with marginalized value had it originally"""
766
- with MarginalModel(coords={"year": [1990, 1991, 1992]}) as m:
866
+ with Model(coords={"year": [1990, 1991, 1992]}) as m:
767
867
  sigma = pm.HalfNormal("sigma")
768
868
  idx = pm.Bernoulli("idx", p=0.75, dims="year")
769
869
  x = pm.Normal("x", mu=idx, sigma=sigma, dims="year")
770
870
 
771
- m.marginalize([idx])
871
+ marginal_m = marginalize(m, [idx])
772
872
  rng = np.random.default_rng(211)
773
873
 
774
- with m:
874
+ with marginal_m:
775
875
  prior = pm.sample_prior_predictive(
776
876
  draws=20,
777
877
  random_seed=rng,
@@ -781,23 +881,23 @@ class TestRecoverMarginals:
781
881
  posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
782
882
  )
783
883
 
784
- idata = m.recover_marginals(idata, return_samples=True)
884
+ idata = recover_marginals(marginal_m, idata, return_samples=True)
785
885
  post = idata.posterior
786
886
  assert post.idx.dims == ("chain", "draw", "year")
787
887
  assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim")
788
888
 
789
889
  def test_batched(self):
790
890
  """Test that marginalization works for batched random variables"""
791
- with MarginalModel() as m:
891
+ with Model() as m:
792
892
  sigma = pm.HalfNormal("sigma")
793
893
  idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
794
894
  y = pm.Normal("y", mu=idx.T, sigma=sigma, shape=(2, 3))
795
895
 
796
- m.marginalize([idx])
896
+ marginal_m = marginalize(m, [idx])
797
897
 
798
898
  rng = np.random.default_rng(211)
799
899
 
800
- with m:
900
+ with marginal_m:
801
901
  prior = pm.sample_prior_predictive(
802
902
  draws=20,
803
903
  random_seed=rng,
@@ -807,7 +907,7 @@ class TestRecoverMarginals:
807
907
  posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
808
908
  )
809
909
 
810
- idata = m.recover_marginals(idata, return_samples=True)
910
+ idata = recover_marginals(marginal_m, idata, return_samples=True)
811
911
  post = idata.posterior
812
912
  assert post["y"].shape == (1, 20, 2, 3)
813
913
  assert post["idx"].shape == (1, 20, 3, 2)
@@ -816,16 +916,16 @@ class TestRecoverMarginals:
816
916
  def test_nested(self):
817
917
  """Test that marginalization works when there are nested marginalized RVs"""
818
918
 
819
- with MarginalModel() as m:
919
+ with Model() as m:
820
920
  idx = pm.Bernoulli("idx", p=0.75)
821
921
  sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95))
822
922
  sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0)
823
923
 
824
- m.marginalize([idx, sub_idx])
924
+ marginal_m = marginalize(m, [idx, sub_idx])
825
925
 
826
926
  rng = np.random.default_rng(211)
827
927
 
828
- with m:
928
+ with marginal_m:
829
929
  prior = pm.sample_prior_predictive(
830
930
  draws=20,
831
931
  random_seed=rng,
@@ -833,7 +933,7 @@ class TestRecoverMarginals:
833
933
  )
834
934
  idata = InferenceData(posterior=dict_to_dataset(prior))
835
935
 
836
- idata = m.recover_marginals(idata, return_samples=True)
936
+ idata = recover_marginals(marginal_m, idata, return_samples=True)
837
937
  post = idata.posterior
838
938
  assert "idx" in post
839
939
  assert "lp_idx" in post