pymc-extras 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,6 @@
1
1
  import warnings
2
2
 
3
3
  from collections.abc import Sequence
4
- from typing import Union
5
4
 
6
5
  import numpy as np
7
6
  import pymc
@@ -13,21 +12,42 @@ from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
13
12
  from pymc.distributions.transforms import Chain
14
13
  from pymc.logprob.transforms import IntervalTransform
15
14
  from pymc.model import Model
16
- from pymc.pytensorf import compile_pymc, constant_fold
17
- from pymc.util import RandomState, _get_seeds_per_chain, treedict
15
+ from pymc.model.fgraph import (
16
+ ModelFreeRV,
17
+ ModelValuedVar,
18
+ fgraph_from_model,
19
+ model_free_rv,
20
+ model_from_fgraph,
21
+ )
22
+ from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold, toposort_replace
23
+ from pymc.util import RandomState, _get_seeds_per_chain
24
+ from pytensor import In, Out
18
25
  from pytensor.compile import SharedVariable
19
- from pytensor.graph import FunctionGraph, clone_replace, graph_inputs
20
- from pytensor.graph.replace import vectorize_graph
26
+ from pytensor.graph import (
27
+ FunctionGraph,
28
+ Variable,
29
+ clone_replace,
30
+ graph_inputs,
31
+ graph_replace,
32
+ node_rewriter,
33
+ vectorize_graph,
34
+ )
35
+ from pytensor.graph.rewriting.basic import in2out
21
36
  from pytensor.tensor import TensorVariable
22
- from pytensor.tensor.special import log_softmax
23
37
 
24
38
  __all__ = ["MarginalModel", "marginalize"]
25
39
 
40
+ from pytensor.tensor.random.type import RandomType
41
+ from pytensor.tensor.special import log_softmax
42
+
26
43
  from pymc_extras.distributions import DiscreteMarkovChain
27
44
  from pymc_extras.model.marginal.distributions import (
28
45
  MarginalDiscreteMarkovChainRV,
29
46
  MarginalFiniteDiscreteRV,
47
+ MarginalRV,
48
+ NonSeparableLogpWarning,
30
49
  get_domain_of_finite_discrete_rv,
50
+ inline_ofg_outputs,
31
51
  reduce_batch_dependent_logps,
32
52
  )
33
53
  from pymc_extras.model.marginal.graph_analysis import (
@@ -87,479 +107,452 @@ class MarginalModel(Model):
87
107
  """
88
108
 
89
109
  def __init__(self, *args, **kwargs):
90
- super().__init__(*args, **kwargs)
91
- self.marginalized_rvs = []
92
- self._marginalized_named_vars_to_dims = {}
110
+ raise TypeError(
111
+ "MarginalModel was deprecated in favor of `marginalize` which now returns a PyMC model"
112
+ )
93
113
 
94
- def _delete_rv_mappings(self, rv: TensorVariable) -> None:
95
- """Remove all model mappings referring to rv
96
114
 
97
- This can be used to "delete" an RV from a model
98
- """
99
- assert rv in self.basic_RVs, "rv is not part of the Model"
115
+ def _warn_interval_transform(rv_to_marginalize, replaced_vars: Sequence[ModelValuedVar]) -> None:
116
+ for replaced_var in replaced_vars:
117
+ if not isinstance(replaced_var.owner.op, ModelValuedVar):
118
+ raise TypeError(f"{replaced_var} is not a ModelValuedVar")
100
119
 
101
- name = rv.name
102
- self.named_vars.pop(name)
103
- if name in self.named_vars_to_dims:
104
- self.named_vars_to_dims.pop(name)
120
+ if not isinstance(replaced_var.owner.op, ModelFreeRV):
121
+ continue
105
122
 
106
- value = self.rvs_to_values.pop(rv)
107
- self.values_to_rvs.pop(value)
123
+ if replaced_var is rv_to_marginalize:
124
+ continue
108
125
 
109
- self.rvs_to_transforms.pop(rv)
110
- if rv in self.free_RVs:
111
- self.free_RVs.remove(rv)
112
- self.rvs_to_initial_values.pop(rv)
113
- else:
114
- self.observed_RVs.remove(rv)
115
-
116
- def _transfer_rv_mappings(self, old_rv: TensorVariable, new_rv: TensorVariable) -> None:
117
- """Transfer model mappings from old_rv to new_rv"""
118
-
119
- assert old_rv in self.basic_RVs, "old_rv is not part of the Model"
120
- assert new_rv not in self.basic_RVs, "new_rv is already part of the Model"
121
-
122
- self.named_vars.pop(old_rv.name)
123
- new_rv.name = old_rv.name
124
- self.named_vars[new_rv.name] = new_rv
125
- if old_rv in self.named_vars_to_dims:
126
- self._RV_dims[new_rv] = self._RV_dims.pop(old_rv)
127
-
128
- value = self.rvs_to_values.pop(old_rv)
129
- self.rvs_to_values[new_rv] = value
130
- self.values_to_rvs[value] = new_rv
131
-
132
- self.rvs_to_transforms[new_rv] = self.rvs_to_transforms.pop(old_rv)
133
- if old_rv in self.free_RVs:
134
- index = self.free_RVs.index(old_rv)
135
- self.free_RVs.pop(index)
136
- self.free_RVs.insert(index, new_rv)
137
- self.rvs_to_initial_values[new_rv] = self.rvs_to_initial_values.pop(old_rv)
138
- elif old_rv in self.observed_RVs:
139
- index = self.observed_RVs.index(old_rv)
140
- self.observed_RVs.pop(index)
141
- self.observed_RVs.insert(index, new_rv)
142
-
143
- def _marginalize(self, user_warnings=False):
144
- fg = FunctionGraph(outputs=self.basic_RVs + self.marginalized_rvs, clone=False)
145
-
146
- toposort = fg.toposort()
147
- rvs_left_to_marginalize = self.marginalized_rvs
148
- for rv_to_marginalize in sorted(
149
- self.marginalized_rvs,
150
- key=lambda rv: toposort.index(rv.owner),
151
- reverse=True,
126
+ transform = replaced_var.owner.op.transform
127
+
128
+ if isinstance(transform, IntervalTransform) or (
129
+ isinstance(transform, Chain)
130
+ and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list)
152
131
  ):
153
- # Check that no deterministics or potentials dependend on the rv to marginalize
154
- for det in self.deterministics:
155
- if is_conditional_dependent(
156
- det, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
157
- ):
158
- raise NotImplementedError(
159
- f"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}"
160
- )
161
- for pot in self.potentials:
162
- if is_conditional_dependent(
163
- pot, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
164
- ):
165
- raise NotImplementedError(
166
- f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
167
- )
168
-
169
- old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph(
170
- fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
132
+ warnings.warn(
133
+ f"The transform {transform} for the variable {replaced_var}, which depends on the "
134
+ f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.",
135
+ UserWarning,
171
136
  )
172
137
 
173
- if user_warnings and len(new_rvs) > 2:
174
- warnings.warn(
175
- "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
176
- f"Their joint logp terms will be assigned to the first RV: {old_rvs[1]}",
177
- UserWarning,
138
+
139
+ def _unique(seq: Sequence) -> list:
140
+ """Copied from https://stackoverflow.com/a/480227"""
141
+ seen = set()
142
+ seen_add = seen.add
143
+ return [x for x in seq if not (x in seen or seen_add(x))]
144
+
145
+
146
+ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
147
+ """Marginalize a subset of variables in a PyMC model.
148
+
149
+ This creates a class of `MarginalModel` from an existing `Model`, with the specified
150
+ variables marginalized.
151
+
152
+ See documentation for `MarginalModel` for more information.
153
+
154
+ Parameters
155
+ ----------
156
+ model : Model
157
+ PyMC model to marginalize. Original variables well be cloned.
158
+ rvs_to_marginalize : Sequence[TensorVariable]
159
+ Variables to marginalize in the returned model.
160
+
161
+ Returns
162
+ -------
163
+ marginal_model: MarginalModel
164
+ Marginal model with the specified variables marginalized.
165
+ """
166
+ if isinstance(rvs_to_marginalize, str | Variable):
167
+ rvs_to_marginalize = (rvs_to_marginalize,)
168
+
169
+ rvs_to_marginalize = [model[rv] if isinstance(rv, str) else rv for rv in rvs_to_marginalize]
170
+
171
+ if not rvs_to_marginalize:
172
+ return model
173
+
174
+ for rv_to_marginalize in rvs_to_marginalize:
175
+ if rv_to_marginalize not in model.free_RVs:
176
+ raise ValueError(f"Marginalized RV {rv_to_marginalize} is not a free RV in the model")
177
+
178
+ rv_op = rv_to_marginalize.owner.op
179
+ if isinstance(rv_op, DiscreteMarkovChain):
180
+ if rv_op.n_lags > 1:
181
+ raise NotImplementedError(
182
+ "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
183
+ )
184
+ if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
185
+ raise NotImplementedError(
186
+ "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
178
187
  )
188
+ elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
189
+ raise NotImplementedError(
190
+ f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
191
+ )
179
192
 
180
- rvs_left_to_marginalize.remove(rv_to_marginalize)
181
-
182
- for old_rv, new_rv in zip(old_rvs, new_rvs):
183
- new_rv.name = old_rv.name
184
- if old_rv in self.marginalized_rvs:
185
- idx = self.marginalized_rvs.index(old_rv)
186
- self.marginalized_rvs.pop(idx)
187
- self.marginalized_rvs.insert(idx, new_rv)
188
- if old_rv in self.basic_RVs:
189
- self._transfer_rv_mappings(old_rv, new_rv)
190
- if user_warnings:
191
- # Interval transforms for dependent variable won't work for non-constant bounds because
192
- # the RV inputs are now different and may depend on another RV that also depends on the
193
- # same marginalized RV
194
- transform = self.rvs_to_transforms[new_rv]
195
- if isinstance(transform, IntervalTransform) or (
196
- isinstance(transform, Chain)
197
- and any(
198
- isinstance(tr, IntervalTransform) for tr in transform.transform_list
199
- )
200
- ):
201
- warnings.warn(
202
- f"The transform {transform} for the variable {old_rv}, which depends on the "
203
- f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.",
204
- UserWarning,
205
- )
206
- return self
207
-
208
- def _logp(self, *args, **kwargs):
209
- return super().logp(*args, **kwargs)
210
-
211
- def logp(self, vars=None, **kwargs):
212
- m = self.clone()._marginalize()
213
- if vars is not None:
214
- if not isinstance(vars, Sequence):
215
- vars = (vars,)
216
- vars = [m[var.name] for var in vars]
217
- return m._logp(vars=vars, **kwargs)
218
-
219
- @staticmethod
220
- def from_model(model: Union[Model, "MarginalModel"]) -> "MarginalModel":
221
- new_model = MarginalModel(coords=model.coords)
222
- if isinstance(model, MarginalModel):
223
- marginalized_rvs = model.marginalized_rvs
224
- marginalized_named_vars_to_dims = model._marginalized_named_vars_to_dims
225
- else:
226
- marginalized_rvs = []
227
- marginalized_named_vars_to_dims = {}
228
-
229
- model_vars = model.basic_RVs + model.potentials + model.deterministics + marginalized_rvs
230
- data_vars = [var for name, var in model.named_vars.items() if var not in model_vars]
231
- vars = model_vars + data_vars
232
- cloned_vars = clone_replace(vars)
233
- vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
234
- new_model.vars_to_clone = vars_to_clone
235
-
236
- new_model.named_vars = treedict(
237
- {name: vars_to_clone[var] for name, var in model.named_vars.items()}
238
- )
239
- new_model.named_vars_to_dims = model.named_vars_to_dims
240
- new_model.values_to_rvs = {vv: vars_to_clone[rv] for vv, rv in model.values_to_rvs.items()}
241
- new_model.rvs_to_values = {vars_to_clone[rv]: vv for rv, vv in model.rvs_to_values.items()}
242
- new_model.rvs_to_transforms = {
243
- vars_to_clone[rv]: tr for rv, tr in model.rvs_to_transforms.items()
244
- }
245
- new_model.rvs_to_initial_values = {
246
- vars_to_clone[rv]: iv for rv, iv in model.rvs_to_initial_values.items()
247
- }
248
- new_model.free_RVs = [vars_to_clone[rv] for rv in model.free_RVs]
249
- new_model.observed_RVs = [vars_to_clone[rv] for rv in model.observed_RVs]
250
- new_model.potentials = [vars_to_clone[pot] for pot in model.potentials]
251
- new_model.deterministics = [vars_to_clone[det] for det in model.deterministics]
252
-
253
- new_model.marginalized_rvs = [vars_to_clone[rv] for rv in marginalized_rvs]
254
- new_model._marginalized_named_vars_to_dims = marginalized_named_vars_to_dims
255
- return new_model
256
-
257
- def clone(self):
258
- return self.from_model(self)
259
-
260
- def marginalize(
261
- self,
262
- rvs_to_marginalize: ModelRVs,
193
+ fg, memo = fgraph_from_model(model)
194
+ rvs_to_marginalize = [memo[rv] for rv in rvs_to_marginalize]
195
+ toposort = fg.toposort()
196
+
197
+ for rv_to_marginalize in sorted(
198
+ rvs_to_marginalize,
199
+ key=lambda rv: toposort.index(rv.owner),
200
+ reverse=True,
263
201
  ):
264
- if not isinstance(rvs_to_marginalize, Sequence):
265
- rvs_to_marginalize = (rvs_to_marginalize,)
266
- rvs_to_marginalize = [
267
- self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
268
- ]
202
+ all_rvs = [node.out for node in fg.toposort() if isinstance(node.op, ModelValuedVar)]
203
+
204
+ dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
205
+ if not dependent_rvs:
206
+ # TODO: This should at most be a warning, not an error
207
+ raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
208
+
209
+ # Issue warning for IntervalTransform on dependent RVs
210
+ for dependent_rv in dependent_rvs:
211
+ transform = dependent_rv.owner.op.transform
269
212
 
270
- for rv_to_marginalize in rvs_to_marginalize:
271
- if rv_to_marginalize not in self.free_RVs:
272
- raise ValueError(
273
- f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
213
+ if isinstance(transform, IntervalTransform) or (
214
+ isinstance(transform, Chain)
215
+ and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list)
216
+ ):
217
+ warnings.warn(
218
+ f"The transform {transform} for the variable {dependent_rv}, which depends on the "
219
+ f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.",
220
+ UserWarning,
274
221
  )
275
222
 
276
- rv_op = rv_to_marginalize.owner.op
277
- if isinstance(rv_op, DiscreteMarkovChain):
278
- if rv_op.n_lags > 1:
279
- raise NotImplementedError(
280
- "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
281
- )
282
- if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
283
- raise NotImplementedError(
284
- "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
285
- )
286
- elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
223
+ # Check that no deterministics or potentials depend on the rv to marginalize
224
+ for det in model.deterministics:
225
+ if is_conditional_dependent(memo[det], rv_to_marginalize, all_rvs):
287
226
  raise NotImplementedError(
288
- f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
227
+ f"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}"
228
+ )
229
+ for pot in model.potentials:
230
+ if is_conditional_dependent(memo[pot], rv_to_marginalize, all_rvs):
231
+ raise NotImplementedError(
232
+ f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
289
233
  )
290
234
 
291
- if rv_to_marginalize.name in self.named_vars_to_dims:
292
- dims = self.named_vars_to_dims[rv_to_marginalize.name]
293
- self._marginalized_named_vars_to_dims[rv_to_marginalize.name] = dims
294
-
295
- self._delete_rv_mappings(rv_to_marginalize)
296
- self.marginalized_rvs.append(rv_to_marginalize)
297
-
298
- # Raise errors and warnings immediately
299
- self.clone()._marginalize(user_warnings=True)
300
-
301
- def _to_transformed(self):
302
- """Create a function from the untransformed space to the transformed space"""
303
- transformed_rvs = []
304
- transformed_names = []
305
-
306
- for rv in self.free_RVs:
307
- transform = self.rvs_to_transforms.get(rv)
308
- if transform is None:
309
- transformed_rvs.append(rv)
310
- transformed_names.append(rv.name)
311
- else:
312
- transformed_rv = transform.forward(rv, *rv.owner.inputs)
313
- transformed_rvs.append(transformed_rv)
314
- transformed_names.append(self.rvs_to_values[rv].name)
315
-
316
- fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
317
- return fn, transformed_names
318
-
319
- def unmarginalize(self, rvs_to_unmarginalize: Sequence[TensorVariable | str]):
320
- for rv in rvs_to_unmarginalize:
321
- if isinstance(rv, str):
322
- rv = self[rv]
323
- self.marginalized_rvs.remove(rv)
324
- if rv.name in self._marginalized_named_vars_to_dims:
325
- dims = self._marginalized_named_vars_to_dims.pop(rv.name)
326
- else:
327
- dims = None
328
- self.register_rv(rv, name=rv.name, dims=dims)
329
-
330
- def recover_marginals(
331
- self,
332
- idata: InferenceData,
333
- var_names: Sequence[str] | None = None,
334
- return_samples: bool = True,
335
- extend_inferencedata: bool = True,
336
- random_seed: RandomState = None,
337
- ):
338
- """Computes posterior log-probabilities and samples of marginalized variables
339
- conditioned on parameters of the model given InferenceData with posterior group
235
+ marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
236
+ other_direct_rv_ancestors = [
237
+ rv
238
+ for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
239
+ if rv is not rv_to_marginalize
240
+ ]
241
+ input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))
340
242
 
341
- When there are multiple marginalized variables, each marginalized variable is
342
- conditioned on both the parameters and the other variables still marginalized
243
+ replace_finite_discrete_marginal_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)
343
244
 
344
- All log-probabilities are within the transformed space
245
+ return model_from_fgraph(fg, mutate_fgraph=True)
345
246
 
346
- Parameters
347
- ----------
348
- idata : InferenceData
349
- InferenceData with posterior group
350
- var_names : sequence of str, optional
351
- List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables
352
- return_samples : bool, default True
353
- If True, also return samples of the marginalized variables
354
- extend_inferencedata : bool, default True
355
- Whether to extend the original InferenceData or return a new one
356
- random_seed: int, array-like of int or SeedSequence, optional
357
- Seed used to generating samples
358
247
 
359
- Returns
360
- -------
361
- idata : InferenceData
362
- InferenceData with where a lp_{varname} and {varname} for each marginalized variable in var_names added to the posterior group
248
+ @node_rewriter(tracks=[MarginalRV])
249
+ def local_unmarginalize(fgraph, node):
250
+ unmarginalized_rv, *dependent_rvs_and_rngs = inline_ofg_outputs(node.op, node.inputs)
251
+ rngs = [rng for rng in dependent_rvs_and_rngs if isinstance(rng.type, RandomType)]
252
+ dependent_rvs = [rv for rv in dependent_rvs_and_rngs if rv not in rngs]
363
253
 
364
- .. code-block:: python
254
+ # Wrap the marginalized RV in a FreeRV
255
+ # TODO: Preserve dims and transform in MarginalRV
256
+ value = unmarginalized_rv.clone()
257
+ fgraph.add_input(value)
258
+ transform = None
259
+ unmarginalized_free_rv = model_free_rv(unmarginalized_rv, value, transform, *node.op.dims)
365
260
 
366
- import pymc as pm
367
- from pymc_extras import MarginalModel
261
+ # Replace references to the marginalized RV with the FreeRV in the dependent RVs
262
+ dependent_rvs = graph_replace(dependent_rvs, {unmarginalized_rv: unmarginalized_free_rv})
368
263
 
369
- with MarginalModel() as m:
370
- p = pm.Beta("p", 1, 1)
371
- x = pm.Bernoulli("x", p=p, shape=(3,))
372
- y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
264
+ return [unmarginalized_free_rv, *dependent_rvs, *rngs]
373
265
 
374
- m.marginalize([x])
375
266
 
376
- idata = pm.sample()
377
- m.recover_marginals(idata, var_names=["x"])
267
+ unmarginalize_rewrite = in2out(local_unmarginalize, ignore_newtrees=False)
378
268
 
379
269
 
380
- """
381
- if var_names is None:
382
- var_names = [var.name for var in self.marginalized_rvs]
270
+ def unmarginalize(model: Model, rvs_to_unmarginalize: str | Sequence[str] | None = None) -> Model:
271
+ """Unmarginalize a subset of variables in a PyMC model.
383
272
 
384
- var_names = [var if isinstance(var, str) else var.name for var in var_names]
385
- vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names]
386
- missing_names = [v.name for v in vars_to_recover if v not in self.marginalized_rvs]
387
- if missing_names:
388
- raise ValueError(f"Unrecognized var_names: {missing_names}")
389
273
 
390
- if return_samples and random_seed is not None:
391
- seeds = _get_seeds_per_chain(random_seed, len(vars_to_recover))
274
+ Parameters
275
+ ----------
276
+ model : Model
277
+ PyMC model to unmarginalize. Original variables well be cloned.
278
+ rvs_to_unmarginalize : str or sequence of str, optional
279
+ Variables to unmarginalize in the returned model. If None, all variables are
280
+ unmarginalized.
281
+
282
+ Returns
283
+ -------
284
+ unmarginal_model: Model
285
+ Model with the specified variables unmarginalized.
286
+ """
287
+
288
+ # Unmarginalize all the MarginalRVs
289
+ fg, memo = fgraph_from_model(model)
290
+ unmarginalize_rewrite(fg)
291
+ unmarginalized_model = model_from_fgraph(fg, mutate_fgraph=True)
292
+ if rvs_to_unmarginalize is None:
293
+ return unmarginalized_model
294
+
295
+ # Re-marginalize the variables we want to keep marginalized
296
+ if not isinstance(rvs_to_unmarginalize, list | tuple):
297
+ rvs_to_unmarginalize = (rvs_to_unmarginalize,)
298
+ rvs_to_unmarginalize = set(rvs_to_unmarginalize)
299
+
300
+ old_free_rv_names = set(rv.name for rv in model.free_RVs)
301
+ new_free_rv_names = set(
302
+ rv.name for rv in unmarginalized_model.free_RVs if rv.name not in old_free_rv_names
303
+ )
304
+ if rvs_to_unmarginalize - new_free_rv_names:
305
+ raise ValueError(
306
+ f"Unrecognized rvs_to_unmarginalize: {rvs_to_unmarginalize - new_free_rv_names}"
307
+ )
308
+ rvs_to_keep_marginalized = tuple(new_free_rv_names - rvs_to_unmarginalize)
309
+ return marginalize(unmarginalized_model, rvs_to_keep_marginalized)
310
+
311
+
312
+ def transform_posterior_pts(model, posterior_pts):
313
+ """Create a function from the untransformed space to the transformed space"""
314
+ # TODO: This should be a utility in PyMC
315
+ transformed_rvs = []
316
+ transformed_names = []
317
+
318
+ for rv in model.free_RVs:
319
+ transform = model.rvs_to_transforms.get(rv)
320
+ if transform is None:
321
+ transformed_rvs.append(rv)
322
+ transformed_names.append(rv.name)
392
323
  else:
393
- seeds = [None] * len(vars_to_recover)
324
+ transformed_rv = transform.forward(rv, *rv.owner.inputs)
325
+ transformed_rvs.append(transformed_rv)
326
+ transformed_names.append(model.rvs_to_values[rv].name)
394
327
 
395
- posterior = idata.posterior
328
+ fn = compile_pymc(
329
+ inputs=[In(inp, borrow=True) for inp in model.free_RVs],
330
+ outputs=[Out(out, borrow=True) for out in transformed_rvs],
331
+ )
332
+ fn.trust_input = True
396
333
 
397
- # Remove Deterministics
398
- posterior_values = posterior[
399
- [rv.name for rv in self.free_RVs if rv not in self.marginalized_rvs]
400
- ]
334
+ # TODO: This should work with vectorized inputs
335
+ return [dict(zip(transformed_names, fn(**point))) for point in posterior_pts]
401
336
 
402
- sample_dims = ("chain", "draw")
403
- posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
404
337
 
405
- # Handle Transforms
406
- transform_fn, transform_names = self._to_transformed()
338
+ def recover_marginals(
339
+ model: Model,
340
+ idata: InferenceData,
341
+ var_names: Sequence[str] | None = None,
342
+ return_samples: bool = True,
343
+ extend_inferencedata: bool = True,
344
+ random_seed: RandomState = None,
345
+ ):
346
+ """Computes posterior log-probabilities and samples of marginalized variables
347
+ conditioned on parameters of the model given InferenceData with posterior group
407
348
 
408
- def transform_input(inputs):
409
- return dict(zip(transform_names, transform_fn(inputs)))
349
+ When there are multiple marginalized variables, each marginalized variable is
350
+ conditioned on both the parameters and the other variables still marginalized
410
351
 
411
- posterior_pts = [transform_input(vs) for vs in posterior_pts]
352
+ All log-probabilities are within the transformed space
412
353
 
413
- rv_dict = {}
414
- rv_dims = {}
415
- for seed, marginalized_rv in zip(seeds, vars_to_recover):
416
- supported_dists = (Bernoulli, Categorical, DiscreteUniform)
417
- if not isinstance(marginalized_rv.owner.op, supported_dists):
418
- raise NotImplementedError(
419
- f"RV with distribution {marginalized_rv.owner.op} cannot be recovered. "
420
- f"Supported distribution include {supported_dists}"
421
- )
354
+ Parameters
355
+ ----------
356
+ model: Model
357
+ PyMC model with marginalized variables to recover
358
+ idata : InferenceData
359
+ InferenceData with posterior group
360
+ var_names : sequence of str, optional
361
+ List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables
362
+ return_samples : bool, default True
363
+ If True, also return samples of the marginalized variables
364
+ extend_inferencedata : bool, default True
365
+ Whether to extend the original InferenceData or return a new one
366
+ random_seed: int, array-like of int or SeedSequence, optional
367
+ Seed used to generating samples
422
368
 
423
- m = self.clone()
424
- marginalized_rv = m.vars_to_clone[marginalized_rv]
425
- m.unmarginalize([marginalized_rv])
426
- dependent_rvs = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
427
- logps = m.logp(vars=[marginalized_rv, *dependent_rvs], sum=False)
369
+ Returns
370
+ -------
371
+ idata : InferenceData
372
+ InferenceData with where a lp_{varname} and {varname} for each marginalized variable in var_names added to the posterior group
428
373
 
429
- # Handle batch dims for marginalized value and its dependent RVs
430
- dependent_rvs_dim_connections = subgraph_batch_dim_connection(
431
- marginalized_rv, dependent_rvs
432
- )
433
- marginalized_logp, *dependent_logps = logps
434
- joint_logp = marginalized_logp + reduce_batch_dependent_logps(
435
- dependent_rvs_dim_connections,
436
- [dependent_var.owner.op for dependent_var in dependent_rvs],
437
- dependent_logps,
438
- )
374
+ .. code-block:: python
439
375
 
440
- marginalized_value = m.rvs_to_values[marginalized_rv]
441
- other_values = [v for v in m.value_vars if v is not marginalized_value]
442
-
443
- rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
444
- rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
445
- rv_domain_tensor = pt.moveaxis(
446
- pt.full(
447
- (*rv_shape, len(rv_domain)),
448
- rv_domain,
449
- dtype=marginalized_rv.dtype,
450
- ),
451
- -1,
452
- 0,
453
- )
376
+ import pymc as pm
377
+ from pymc_extras import MarginalModel
454
378
 
455
- batched_joint_logp = vectorize_graph(
456
- joint_logp,
457
- replace={marginalized_value: rv_domain_tensor},
458
- )
459
- batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1)
460
-
461
- joint_logp_norm = log_softmax(batched_joint_logp, axis=-1)
462
- if return_samples:
463
- rv_draws = pymc.Categorical.dist(logit_p=batched_joint_logp)
464
- if isinstance(marginalized_rv.owner.op, DiscreteUniform):
465
- rv_draws += rv_domain[0]
466
- outputs = [joint_logp_norm, rv_draws]
467
- else:
468
- outputs = joint_logp_norm
469
-
470
- rv_loglike_fn = compile_pymc(
471
- inputs=other_values,
472
- outputs=outputs,
473
- on_unused_input="ignore",
474
- random_seed=seed,
379
+ with MarginalModel() as m:
380
+ p = pm.Beta("p", 1, 1)
381
+ x = pm.Bernoulli("x", p=p, shape=(3,))
382
+ y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
383
+
384
+ m.marginalize([x])
385
+
386
+ idata = pm.sample()
387
+ m.recover_marginals(idata, var_names=["x"])
388
+
389
+
390
+ """
391
+ unmarginal_model = unmarginalize(model)
392
+
393
+ # Find the names of the marginalized variables
394
+ model_var_names = set(rv.name for rv in model.free_RVs)
395
+ marginalized_rv_names = [
396
+ rv.name for rv in unmarginal_model.free_RVs if rv.name not in model_var_names
397
+ ]
398
+
399
+ if var_names is None:
400
+ var_names = marginalized_rv_names
401
+
402
+ var_names = [var if isinstance(var, str) else var.name for var in var_names]
403
+ var_names_to_recover = [name for name in marginalized_rv_names if name in var_names]
404
+ missing_names = [name for name in var_names_to_recover if name not in marginalized_rv_names]
405
+ if missing_names:
406
+ raise ValueError(f"Unrecognized var_names: {missing_names}")
407
+
408
+ if return_samples and random_seed is not None:
409
+ seeds = _get_seeds_per_chain(random_seed, len(var_names_to_recover))
410
+ else:
411
+ seeds = [None] * len(var_names_to_recover)
412
+
413
+ posterior_pts, stacked_dims = dataset_to_point_list(
414
+ # Remove Deterministics
415
+ idata.posterior[[rv.name for rv in model.free_RVs]],
416
+ sample_dims=("chain", "draw"),
417
+ )
418
+ transformed_posterior_pts = transform_posterior_pts(model, posterior_pts)
419
+
420
+ rv_dict = {}
421
+ rv_dims = {}
422
+ for seed, var_name_to_recover in zip(seeds, var_names_to_recover):
423
+ var_to_recover = unmarginal_model[var_name_to_recover]
424
+ supported_dists = (Bernoulli, Categorical, DiscreteUniform)
425
+ if not isinstance(var_to_recover.owner.op, supported_dists):
426
+ raise NotImplementedError(
427
+ f"RV with distribution {var_to_recover.owner.op} cannot be recovered. "
428
+ f"Supported distribution include {supported_dists}"
475
429
  )
476
430
 
477
- logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]
431
+ other_marginalized_rvs_names = marginalized_rv_names.copy()
432
+ other_marginalized_rvs_names.remove(var_name_to_recover)
433
+ dependent_rvs = [
434
+ rv
435
+ for rv in find_conditional_dependent_rvs(var_to_recover, unmarginal_model.basic_RVs)
436
+ if rv.name not in other_marginalized_rvs_names
437
+ ]
438
+ # Handle batch dims for marginalized value and its dependent RVs
439
+ dependent_rvs_dim_connections = subgraph_batch_dim_connection(var_to_recover, dependent_rvs)
440
+
441
+ marginalized_model = marginalize(unmarginal_model, other_marginalized_rvs_names)
478
442
 
479
- if return_samples:
480
- logps, samples = zip(*logvs)
481
- logps = np.array(logps)
482
- samples = np.array(samples)
483
- rv_dict[marginalized_rv.name] = samples.reshape(
484
- tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
485
- )
486
- else:
487
- logps = np.array(logvs)
443
+ marginalized_var_to_recover = marginalized_model[var_name_to_recover]
444
+ dependent_rvs = [marginalized_model[rv.name] for rv in dependent_rvs]
488
445
 
489
- rv_dict["lp_" + marginalized_rv.name] = logps.reshape(
490
- tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
446
+ with warnings.catch_warnings():
447
+ warnings.filterwarnings("ignore", category=NonSeparableLogpWarning)
448
+ logps = marginalized_model.logp(
449
+ vars=[marginalized_var_to_recover, *dependent_rvs], sum=False
491
450
  )
492
- if marginalized_rv.name in m.named_vars_to_dims:
493
- rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name])
494
- rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [
495
- "lp_" + marginalized_rv.name + "_dim"
496
- ]
497
-
498
- coords, dims = coords_and_dims_for_inferencedata(self)
499
- dims.update(rv_dims)
500
- rv_dataset = dict_to_dataset(
501
- rv_dict,
502
- library=pymc,
503
- dims=dims,
504
- coords=coords,
505
- default_dims=list(sample_dims),
506
- skip_event_dims=True,
507
- )
508
451
 
509
- if extend_inferencedata:
510
- idata.posterior = idata.posterior.assign(rv_dataset)
511
- return idata
512
- else:
513
- return rv_dataset
452
+ marginalized_logp, *dependent_logps = logps
453
+ joint_logp = marginalized_logp + reduce_batch_dependent_logps(
454
+ dependent_rvs_dim_connections,
455
+ [dependent_var.owner.op for dependent_var in dependent_rvs],
456
+ dependent_logps,
457
+ )
514
458
 
459
+ marginalized_value = marginalized_model.rvs_to_values[marginalized_var_to_recover]
460
+ other_values = [v for v in marginalized_model.value_vars if v is not marginalized_value]
461
+
462
+ rv_shape = constant_fold(tuple(var_to_recover.shape), raise_not_constant=False)
463
+ rv_domain = get_domain_of_finite_discrete_rv(var_to_recover)
464
+ rv_domain_tensor = pt.moveaxis(
465
+ pt.full(
466
+ (*rv_shape, len(rv_domain)),
467
+ rv_domain,
468
+ dtype=var_to_recover.dtype,
469
+ ),
470
+ -1,
471
+ 0,
472
+ )
515
473
 
516
- def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
517
- """Marginalize a subset of variables in a PyMC model.
474
+ batched_joint_logp = vectorize_graph(
475
+ joint_logp,
476
+ replace={marginalized_value: rv_domain_tensor},
477
+ )
478
+ batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1)
479
+
480
+ joint_logp_norm = log_softmax(batched_joint_logp, axis=-1)
481
+ if return_samples:
482
+ rv_draws = Categorical.dist(logit_p=batched_joint_logp)
483
+ if isinstance(var_to_recover.owner.op, DiscreteUniform):
484
+ rv_draws += rv_domain[0]
485
+ outputs = [joint_logp_norm, rv_draws]
486
+ else:
487
+ outputs = joint_logp_norm
518
488
 
519
- This creates a class of `MarginalModel` from an existing `Model`, with the specified
520
- variables marginalized.
489
+ rv_loglike_fn = compile_pymc(
490
+ inputs=other_values,
491
+ outputs=outputs,
492
+ on_unused_input="ignore",
493
+ random_seed=seed,
494
+ )
521
495
 
522
- See documentation for `MarginalModel` for more information.
496
+ logvs = [rv_loglike_fn(**vs) for vs in transformed_posterior_pts]
523
497
 
524
- Parameters
525
- ----------
526
- model : Model
527
- PyMC model to marginalize. Original variables well be cloned.
528
- rvs_to_marginalize : Sequence[TensorVariable]
529
- Variables to marginalize in the returned model.
498
+ if return_samples:
499
+ logps, samples = zip(*logvs)
500
+ logps = np.asarray(logps)
501
+ samples = np.asarray(samples)
502
+ rv_dict[var_name_to_recover] = samples.reshape(
503
+ tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
504
+ )
505
+ else:
506
+ logps = np.asarray(logvs)
530
507
 
531
- Returns
532
- -------
533
- marginal_model: MarginalModel
534
- Marginal model with the specified variables marginalized.
535
- """
536
- if not isinstance(rvs_to_marginalize, tuple | list):
537
- rvs_to_marginalize = (rvs_to_marginalize,)
538
- rvs_to_marginalize = [rv if isinstance(rv, str) else rv.name for rv in rvs_to_marginalize]
508
+ rv_dict["lp_" + var_name_to_recover] = logps.reshape(
509
+ tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
510
+ )
511
+ if var_name_to_recover in unmarginal_model.named_vars_to_dims:
512
+ rv_dims[var_name_to_recover] = list(
513
+ unmarginal_model.named_vars_to_dims[var_name_to_recover]
514
+ )
515
+ rv_dims["lp_" + var_name_to_recover] = rv_dims[var_name_to_recover] + [
516
+ "lp_" + var_name_to_recover + "_dim"
517
+ ]
518
+
519
+ coords, dims = coords_and_dims_for_inferencedata(unmarginal_model)
520
+ dims.update(rv_dims)
521
+ rv_dataset = dict_to_dataset(
522
+ rv_dict,
523
+ library=pymc,
524
+ dims=dims,
525
+ coords=coords,
526
+ skip_event_dims=True,
527
+ )
539
528
 
540
- marginal_model = MarginalModel.from_model(model)
541
- marginal_model.marginalize(rvs_to_marginalize)
542
- return marginal_model
529
+ if extend_inferencedata:
530
+ idata.posterior = idata.posterior.assign(rv_dataset)
531
+ return idata
532
+ else:
533
+ return rv_dataset
543
534
 
544
535
 
545
536
  def collect_shared_vars(outputs, blockers):
546
537
  return [
547
- inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)
538
+ inp
539
+ for inp in graph_inputs(outputs, blockers=blockers)
540
+ if (isinstance(inp, SharedVariable) and inp not in blockers)
548
541
  ]
549
542
 
550
543
 
551
- def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
552
- dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
553
- if not dependent_rvs:
554
- raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
544
+ def remove_model_vars(vars):
545
+ """Remove ModelVars from the graph of vars."""
546
+ model_vars = [var for var in vars if isinstance(var.owner.op, ModelValuedVar)]
547
+ replacements = [(model_var, model_var.owner.inputs[0]) for model_var in model_vars]
548
+ fgraph = FunctionGraph(outputs=vars, clone=False)
549
+ toposort_replace(fgraph, replacements)
550
+ return fgraph.outputs
555
551
 
556
- marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
557
- other_direct_rv_ancestors = [
558
- rv
559
- for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
560
- if rv is not rv_to_marginalize
561
- ]
562
552
 
553
+ def replace_finite_discrete_marginal_subgraph(
554
+ fgraph, rv_to_marginalize, dependent_rvs, input_rvs
555
+ ) -> None:
563
556
  # If the marginalized RV has multiple dimensions, check that graph between
564
557
  # marginalized RV and dependent RVs does not mix information from batch dimensions
565
558
  # (otherwise logp would require enumerating over all combinations of batch dimension values)
@@ -574,22 +567,42 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
574
567
  "You can try splitting the marginalized RV into separate components and marginalizing them separately."
575
568
  ) from e
576
569
 
577
- input_rvs = list(set((*marginalized_rv_input_rvs, *other_direct_rv_ancestors)))
578
570
  output_rvs = [rv_to_marginalize, *dependent_rvs]
571
+ rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)
572
+ outputs = output_rvs + list(rng_updates.values())
573
+ inputs = input_rvs + list(rng_updates.keys())
574
+ # Add any other shared variable inputs
575
+ inputs += collect_shared_vars(output_rvs, blockers=inputs)
579
576
 
580
- # We are strict about shared variables in SymbolicRandomVariables
581
- inputs = input_rvs + collect_shared_vars(output_rvs, blockers=input_rvs)
577
+ inner_inputs = [inp.clone() for inp in inputs]
578
+ inner_outputs = clone_replace(outputs, replace=dict(zip(inputs, inner_inputs)))
579
+ inner_outputs = remove_model_vars(inner_outputs)
582
580
 
583
- if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
581
+ if isinstance(inner_outputs[0].owner.op, DiscreteMarkovChain):
584
582
  marginalize_constructor = MarginalDiscreteMarkovChainRV
585
583
  else:
586
584
  marginalize_constructor = MarginalFiniteDiscreteRV
587
585
 
586
+ _, _, *dims = rv_to_marginalize.owner.inputs
588
587
  marginalization_op = marginalize_constructor(
589
- inputs=inputs,
590
- outputs=output_rvs, # TODO: Add RNG updates to outputs so this can be used in the generative graph
588
+ inputs=inner_inputs,
589
+ outputs=inner_outputs,
591
590
  dims_connections=dependent_rvs_dim_connections,
591
+ dims=dims,
592
592
  )
593
- new_output_rvs = marginalization_op(*inputs)
594
- fgraph.replace_all(tuple(zip(output_rvs, new_output_rvs)))
595
- return output_rvs, new_output_rvs
593
+
594
+ new_outputs = marginalization_op(*inputs)
595
+ for old_output, new_output in zip(outputs, new_outputs):
596
+ new_output.name = old_output.name
597
+
598
+ model_replacements = []
599
+ for old_output, new_output in zip(outputs, new_outputs):
600
+ if old_output is rv_to_marginalize or not isinstance(old_output.owner.op, ModelValuedVar):
601
+ # Replace the marginalized ModelFreeRV (or non model-variables) themselves
602
+ var_to_replace = old_output
603
+ else:
604
+ # Replace the underlying RV, keeping the same value, transform and dims
605
+ var_to_replace = old_output.owner.inputs[0]
606
+ model_replacements.append((var_to_replace, new_output))
607
+
608
+ fgraph.replace_all(model_replacements)