pymc-extras 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +5 -1
- pymc_extras/distributions/timeseries.py +1 -1
- pymc_extras/model/marginal/distributions.py +100 -3
- pymc_extras/model/marginal/graph_analysis.py +8 -9
- pymc_extras/model/marginal/marginal_model.py +437 -424
- pymc_extras/statespace/models/structural.py +21 -6
- pymc_extras/utils/model_equivalence.py +66 -0
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/METADATA +3 -4
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/RECORD +18 -17
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/WHEEL +1 -1
- tests/model/marginal/test_distributions.py +12 -11
- tests/model/marginal/test_marginal_model.py +301 -201
- tests/statespace/test_structural.py +10 -3
- tests/test_pivoted_cholesky.py +1 -1
- tests/utils.py +0 -31
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/LICENSE +0 -0
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/top_level.txt +0 -0
pymc_extras/model/marginal/marginal_model.py

@@ -1,7 +1,6 @@
 import warnings

 from collections.abc import Sequence
-from typing import Union

 import numpy as np
 import pymc
@@ -13,21 +12,42 @@ from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
-from pymc.
-
+from pymc.model.fgraph import (
+    ModelFreeRV,
+    ModelValuedVar,
+    fgraph_from_model,
+    model_free_rv,
+    model_from_fgraph,
+)
+from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold, toposort_replace
+from pymc.util import RandomState, _get_seeds_per_chain
+from pytensor import In, Out
 from pytensor.compile import SharedVariable
-from pytensor.graph import
-
+from pytensor.graph import (
+    FunctionGraph,
+    Variable,
+    clone_replace,
+    graph_inputs,
+    graph_replace,
+    node_rewriter,
+    vectorize_graph,
+)
+from pytensor.graph.rewriting.basic import in2out
 from pytensor.tensor import TensorVariable
-from pytensor.tensor.special import log_softmax

 __all__ = ["MarginalModel", "marginalize"]

+from pytensor.tensor.random.type import RandomType
+from pytensor.tensor.special import log_softmax
+
 from pymc_extras.distributions import DiscreteMarkovChain
 from pymc_extras.model.marginal.distributions import (
     MarginalDiscreteMarkovChainRV,
     MarginalFiniteDiscreteRV,
+    MarginalRV,
+    NonSeparableLogpWarning,
     get_domain_of_finite_discrete_rv,
+    inline_ofg_outputs,
     reduce_batch_dependent_logps,
 )
 from pymc_extras.model.marginal.graph_analysis import (
@@ -87,479 +107,452 @@ class MarginalModel(Model):
     """

     def __init__(self, *args, **kwargs):
-
-
-
+        raise TypeError(
+            "MarginalModel was deprecated in favor of `marginalize` which now returns a PyMC model"
+        )

-    def _delete_rv_mappings(self, rv: TensorVariable) -> None:
-        """Remove all model mappings referring to rv

-
-
-
+def _warn_interval_transform(rv_to_marginalize, replaced_vars: Sequence[ModelValuedVar]) -> None:
+    for replaced_var in replaced_vars:
+        if not isinstance(replaced_var.owner.op, ModelValuedVar):
+            raise TypeError(f"{replaced_var} is not a ModelValuedVar")

-
-
-        if name in self.named_vars_to_dims:
-            self.named_vars_to_dims.pop(name)
+        if not isinstance(replaced_var.owner.op, ModelFreeRV):
+            continue

-
-
+        if replaced_var is rv_to_marginalize:
+            continue

-
-
-
-
-
-            self.observed_RVs.remove(rv)
-
-    def _transfer_rv_mappings(self, old_rv: TensorVariable, new_rv: TensorVariable) -> None:
-        """Transfer model mappings from old_rv to new_rv"""
-
-        assert old_rv in self.basic_RVs, "old_rv is not part of the Model"
-        assert new_rv not in self.basic_RVs, "new_rv is already part of the Model"
-
-        self.named_vars.pop(old_rv.name)
-        new_rv.name = old_rv.name
-        self.named_vars[new_rv.name] = new_rv
-        if old_rv in self.named_vars_to_dims:
-            self._RV_dims[new_rv] = self._RV_dims.pop(old_rv)
-
-        value = self.rvs_to_values.pop(old_rv)
-        self.rvs_to_values[new_rv] = value
-        self.values_to_rvs[value] = new_rv
-
-        self.rvs_to_transforms[new_rv] = self.rvs_to_transforms.pop(old_rv)
-        if old_rv in self.free_RVs:
-            index = self.free_RVs.index(old_rv)
-            self.free_RVs.pop(index)
-            self.free_RVs.insert(index, new_rv)
-            self.rvs_to_initial_values[new_rv] = self.rvs_to_initial_values.pop(old_rv)
-        elif old_rv in self.observed_RVs:
-            index = self.observed_RVs.index(old_rv)
-            self.observed_RVs.pop(index)
-            self.observed_RVs.insert(index, new_rv)
-
-    def _marginalize(self, user_warnings=False):
-        fg = FunctionGraph(outputs=self.basic_RVs + self.marginalized_rvs, clone=False)
-
-        toposort = fg.toposort()
-        rvs_left_to_marginalize = self.marginalized_rvs
-        for rv_to_marginalize in sorted(
-            self.marginalized_rvs,
-            key=lambda rv: toposort.index(rv.owner),
-            reverse=True,
+        transform = replaced_var.owner.op.transform
+
+        if isinstance(transform, IntervalTransform) or (
+            isinstance(transform, Chain)
+            and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list)
         ):
-
-
-                if
-
-                ):
-                    raise NotImplementedError(
-                        f"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}"
-                    )
-            for pot in self.potentials:
-                if is_conditional_dependent(
-                    pot, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
-                ):
-                    raise NotImplementedError(
-                        f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
-                    )
-
-            old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph(
-                fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
+            warnings.warn(
+                f"The transform {transform} for the variable {replaced_var}, which depends on the "
+                f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.",
+                UserWarning,
             )

-
-
-
-
-
+
+def _unique(seq: Sequence) -> list:
+    """Copied from https://stackoverflow.com/a/480227"""
+    seen = set()
+    seen_add = seen.add
+    return [x for x in seq if not (x in seen or seen_add(x))]
+
+
+def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
+    """Marginalize a subset of variables in a PyMC model.
+
+    This creates a class of `MarginalModel` from an existing `Model`, with the specified
+    variables marginalized.
+
+    See documentation for `MarginalModel` for more information.
+
+    Parameters
+    ----------
+    model : Model
+        PyMC model to marginalize. Original variables well be cloned.
+    rvs_to_marginalize : Sequence[TensorVariable]
+        Variables to marginalize in the returned model.
+
+    Returns
+    -------
+    marginal_model: MarginalModel
+        Marginal model with the specified variables marginalized.
+    """
+    if isinstance(rvs_to_marginalize, str | Variable):
+        rvs_to_marginalize = (rvs_to_marginalize,)
+
+    rvs_to_marginalize = [model[rv] if isinstance(rv, str) else rv for rv in rvs_to_marginalize]
+
+    if not rvs_to_marginalize:
+        return model
+
+    for rv_to_marginalize in rvs_to_marginalize:
+        if rv_to_marginalize not in model.free_RVs:
+            raise ValueError(f"Marginalized RV {rv_to_marginalize} is not a free RV in the model")
+
+        rv_op = rv_to_marginalize.owner.op
+        if isinstance(rv_op, DiscreteMarkovChain):
+            if rv_op.n_lags > 1:
+                raise NotImplementedError(
+                    "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
+                )
+            if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
+                raise NotImplementedError(
+                    "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
                 )
+        elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
+            raise NotImplementedError(
+                f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
+            )

-
-
-
-
-
-
-
-
-                if old_rv in self.basic_RVs:
-                    self._transfer_rv_mappings(old_rv, new_rv)
-                    if user_warnings:
-                        # Interval transforms for dependent variable won't work for non-constant bounds because
-                        # the RV inputs are now different and may depend on another RV that also depends on the
-                        # same marginalized RV
-                        transform = self.rvs_to_transforms[new_rv]
-                        if isinstance(transform, IntervalTransform) or (
-                            isinstance(transform, Chain)
-                            and any(
-                                isinstance(tr, IntervalTransform) for tr in transform.transform_list
-                            )
-                        ):
-                            warnings.warn(
-                                f"The transform {transform} for the variable {old_rv}, which depends on the "
-                                f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.",
-                                UserWarning,
-                            )
-        return self
-
-    def _logp(self, *args, **kwargs):
-        return super().logp(*args, **kwargs)
-
-    def logp(self, vars=None, **kwargs):
-        m = self.clone()._marginalize()
-        if vars is not None:
-            if not isinstance(vars, Sequence):
-                vars = (vars,)
-            vars = [m[var.name] for var in vars]
-        return m._logp(vars=vars, **kwargs)
-
-    @staticmethod
-    def from_model(model: Union[Model, "MarginalModel"]) -> "MarginalModel":
-        new_model = MarginalModel(coords=model.coords)
-        if isinstance(model, MarginalModel):
-            marginalized_rvs = model.marginalized_rvs
-            marginalized_named_vars_to_dims = model._marginalized_named_vars_to_dims
-        else:
-            marginalized_rvs = []
-            marginalized_named_vars_to_dims = {}
-
-        model_vars = model.basic_RVs + model.potentials + model.deterministics + marginalized_rvs
-        data_vars = [var for name, var in model.named_vars.items() if var not in model_vars]
-        vars = model_vars + data_vars
-        cloned_vars = clone_replace(vars)
-        vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
-        new_model.vars_to_clone = vars_to_clone
-
-        new_model.named_vars = treedict(
-            {name: vars_to_clone[var] for name, var in model.named_vars.items()}
-        )
-        new_model.named_vars_to_dims = model.named_vars_to_dims
-        new_model.values_to_rvs = {vv: vars_to_clone[rv] for vv, rv in model.values_to_rvs.items()}
-        new_model.rvs_to_values = {vars_to_clone[rv]: vv for rv, vv in model.rvs_to_values.items()}
-        new_model.rvs_to_transforms = {
-            vars_to_clone[rv]: tr for rv, tr in model.rvs_to_transforms.items()
-        }
-        new_model.rvs_to_initial_values = {
-            vars_to_clone[rv]: iv for rv, iv in model.rvs_to_initial_values.items()
-        }
-        new_model.free_RVs = [vars_to_clone[rv] for rv in model.free_RVs]
-        new_model.observed_RVs = [vars_to_clone[rv] for rv in model.observed_RVs]
-        new_model.potentials = [vars_to_clone[pot] for pot in model.potentials]
-        new_model.deterministics = [vars_to_clone[det] for det in model.deterministics]
-
-        new_model.marginalized_rvs = [vars_to_clone[rv] for rv in marginalized_rvs]
-        new_model._marginalized_named_vars_to_dims = marginalized_named_vars_to_dims
-        return new_model
-
-    def clone(self):
-        return self.from_model(self)
-
-    def marginalize(
-        self,
-        rvs_to_marginalize: ModelRVs,
+    fg, memo = fgraph_from_model(model)
+    rvs_to_marginalize = [memo[rv] for rv in rvs_to_marginalize]
+    toposort = fg.toposort()
+
+    for rv_to_marginalize in sorted(
+        rvs_to_marginalize,
+        key=lambda rv: toposort.index(rv.owner),
+        reverse=True,
     ):
-        if
-
-
-
-
+        all_rvs = [node.out for node in fg.toposort() if isinstance(node.op, ModelValuedVar)]
+
+        dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
+        if not dependent_rvs:
+            # TODO: This should at most be a warning, not an error
+            raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
+
+        # Issue warning for IntervalTransform on dependent RVs
+        for dependent_rv in dependent_rvs:
+            transform = dependent_rv.owner.op.transform

-
-
-
-
+            if isinstance(transform, IntervalTransform) or (
+                isinstance(transform, Chain)
+                and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list)
+            ):
+                warnings.warn(
+                    f"The transform {transform} for the variable {dependent_rv}, which depends on the "
+                    f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.",
+                    UserWarning,
                 )

-
-
-
-                    raise NotImplementedError(
-                        "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
-                    )
-                if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
-                    raise NotImplementedError(
-                        "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
-                    )
-            elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
+        # Check that no deterministics or potentials depend on the rv to marginalize
+        for det in model.deterministics:
+            if is_conditional_dependent(memo[det], rv_to_marginalize, all_rvs):
                raise NotImplementedError(
-                    f"
+                    f"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}"
+                )
+        for pot in model.potentials:
+            if is_conditional_dependent(memo[pot], rv_to_marginalize, all_rvs):
+                raise NotImplementedError(
+                    f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
                )

-
-
-
-
-
-
-
-        # Raise errors and warnings immediately
-        self.clone()._marginalize(user_warnings=True)
-
-    def _to_transformed(self):
-        """Create a function from the untransformed space to the transformed space"""
-        transformed_rvs = []
-        transformed_names = []
-
-        for rv in self.free_RVs:
-            transform = self.rvs_to_transforms.get(rv)
-            if transform is None:
-                transformed_rvs.append(rv)
-                transformed_names.append(rv.name)
-            else:
-                transformed_rv = transform.forward(rv, *rv.owner.inputs)
-                transformed_rvs.append(transformed_rv)
-                transformed_names.append(self.rvs_to_values[rv].name)
-
-        fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
-        return fn, transformed_names
-
-    def unmarginalize(self, rvs_to_unmarginalize: Sequence[TensorVariable | str]):
-        for rv in rvs_to_unmarginalize:
-            if isinstance(rv, str):
-                rv = self[rv]
-            self.marginalized_rvs.remove(rv)
-            if rv.name in self._marginalized_named_vars_to_dims:
-                dims = self._marginalized_named_vars_to_dims.pop(rv.name)
-            else:
-                dims = None
-            self.register_rv(rv, name=rv.name, dims=dims)
-
-    def recover_marginals(
-        self,
-        idata: InferenceData,
-        var_names: Sequence[str] | None = None,
-        return_samples: bool = True,
-        extend_inferencedata: bool = True,
-        random_seed: RandomState = None,
-    ):
-        """Computes posterior log-probabilities and samples of marginalized variables
-        conditioned on parameters of the model given InferenceData with posterior group
+        marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
+        other_direct_rv_ancestors = [
+            rv
+            for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
+            if rv is not rv_to_marginalize
+        ]
+        input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))

-
-        conditioned on both the parameters and the other variables still marginalized
+        replace_finite_discrete_marginal_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)

-
+    return model_from_fgraph(fg, mutate_fgraph=True)

-        Parameters
-        ----------
-        idata : InferenceData
-            InferenceData with posterior group
-        var_names : sequence of str, optional
-            List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables
-        return_samples : bool, default True
-            If True, also return samples of the marginalized variables
-        extend_inferencedata : bool, default True
-            Whether to extend the original InferenceData or return a new one
-        random_seed: int, array-like of int or SeedSequence, optional
-            Seed used to generating samples

-
-
-
-
+@node_rewriter(tracks=[MarginalRV])
+def local_unmarginalize(fgraph, node):
+    unmarginalized_rv, *dependent_rvs_and_rngs = inline_ofg_outputs(node.op, node.inputs)
+    rngs = [rng for rng in dependent_rvs_and_rngs if isinstance(rng.type, RandomType)]
+    dependent_rvs = [rv for rv in dependent_rvs_and_rngs if rv not in rngs]

-
+    # Wrap the marginalized RV in a FreeRV
+    # TODO: Preserve dims and transform in MarginalRV
+    value = unmarginalized_rv.clone()
+    fgraph.add_input(value)
+    transform = None
+    unmarginalized_free_rv = model_free_rv(unmarginalized_rv, value, transform, *node.op.dims)

-
-
+    # Replace references to the marginalized RV with the FreeRV in the dependent RVs
+    dependent_rvs = graph_replace(dependent_rvs, {unmarginalized_rv: unmarginalized_free_rv})

-
-                p = pm.Beta("p", 1, 1)
-                x = pm.Bernoulli("x", p=p, shape=(3,))
-                y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
+    return [unmarginalized_free_rv, *dependent_rvs, *rngs]

-                m.marginalize([x])

-
-                m.recover_marginals(idata, var_names=["x"])
+unmarginalize_rewrite = in2out(local_unmarginalize, ignore_newtrees=False)


-
-
-            var_names = [var.name for var in self.marginalized_rvs]
+def unmarginalize(model: Model, rvs_to_unmarginalize: str | Sequence[str] | None = None) -> Model:
+    """Unmarginalize a subset of variables in a PyMC model.

-        var_names = [var if isinstance(var, str) else var.name for var in var_names]
-        vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names]
-        missing_names = [v.name for v in vars_to_recover if v not in self.marginalized_rvs]
-        if missing_names:
-            raise ValueError(f"Unrecognized var_names: {missing_names}")

-
-
+    Parameters
+    ----------
+    model : Model
+        PyMC model to unmarginalize. Original variables well be cloned.
+    rvs_to_unmarginalize : str or sequence of str, optional
+        Variables to unmarginalize in the returned model. If None, all variables are
+        unmarginalized.
+
+    Returns
+    -------
+    unmarginal_model: Model
+        Model with the specified variables unmarginalized.
+    """
+
+    # Unmarginalize all the MarginalRVs
+    fg, memo = fgraph_from_model(model)
+    unmarginalize_rewrite(fg)
+    unmarginalized_model = model_from_fgraph(fg, mutate_fgraph=True)
+    if rvs_to_unmarginalize is None:
+        return unmarginalized_model
+
+    # Re-marginalize the variables we want to keep marginalized
+    if not isinstance(rvs_to_unmarginalize, list | tuple):
+        rvs_to_unmarginalize = (rvs_to_unmarginalize,)
+    rvs_to_unmarginalize = set(rvs_to_unmarginalize)
+
+    old_free_rv_names = set(rv.name for rv in model.free_RVs)
+    new_free_rv_names = set(
+        rv.name for rv in unmarginalized_model.free_RVs if rv.name not in old_free_rv_names
+    )
+    if rvs_to_unmarginalize - new_free_rv_names:
+        raise ValueError(
+            f"Unrecognized rvs_to_unmarginalize: {rvs_to_unmarginalize - new_free_rv_names}"
+        )
+    rvs_to_keep_marginalized = tuple(new_free_rv_names - rvs_to_unmarginalize)
+    return marginalize(unmarginalized_model, rvs_to_keep_marginalized)
+
+
+def transform_posterior_pts(model, posterior_pts):
+    """Create a function from the untransformed space to the transformed space"""
+    # TODO: This should be a utility in PyMC
+    transformed_rvs = []
+    transformed_names = []
+
+    for rv in model.free_RVs:
+        transform = model.rvs_to_transforms.get(rv)
+        if transform is None:
+            transformed_rvs.append(rv)
+            transformed_names.append(rv.name)
         else:
-
+            transformed_rv = transform.forward(rv, *rv.owner.inputs)
+            transformed_rvs.append(transformed_rv)
+            transformed_names.append(model.rvs_to_values[rv].name)

-
+    fn = compile_pymc(
+        inputs=[In(inp, borrow=True) for inp in model.free_RVs],
+        outputs=[Out(out, borrow=True) for out in transformed_rvs],
+    )
+    fn.trust_input = True

-
-
-            [rv.name for rv in self.free_RVs if rv not in self.marginalized_rvs]
-        ]
+    # TODO: This should work with vectorized inputs
+    return [dict(zip(transformed_names, fn(**point))) for point in posterior_pts]

-        sample_dims = ("chain", "draw")
-        posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)

-
-
+def recover_marginals(
+    model: Model,
+    idata: InferenceData,
+    var_names: Sequence[str] | None = None,
+    return_samples: bool = True,
+    extend_inferencedata: bool = True,
+    random_seed: RandomState = None,
+):
+    """Computes posterior log-probabilities and samples of marginalized variables
+    conditioned on parameters of the model given InferenceData with posterior group

-
-
+    When there are multiple marginalized variables, each marginalized variable is
+    conditioned on both the parameters and the other variables still marginalized

-
+    All log-probabilities are within the transformed space

-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    model: Model
+        PyMC model with marginalized variables to recover
+    idata : InferenceData
+        InferenceData with posterior group
+    var_names : sequence of str, optional
+        List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables
+    return_samples : bool, default True
+        If True, also return samples of the marginalized variables
+    extend_inferencedata : bool, default True
+        Whether to extend the original InferenceData or return a new one
+    random_seed: int, array-like of int or SeedSequence, optional
+        Seed used to generating samples

-
-
-
-
-            logps = m.logp(vars=[marginalized_rv, *dependent_rvs], sum=False)
+    Returns
+    -------
+    idata : InferenceData
+        InferenceData with where a lp_{varname} and {varname} for each marginalized variable in var_names added to the posterior group

-
-            dependent_rvs_dim_connections = subgraph_batch_dim_connection(
-                marginalized_rv, dependent_rvs
-            )
-            marginalized_logp, *dependent_logps = logps
-            joint_logp = marginalized_logp + reduce_batch_dependent_logps(
-                dependent_rvs_dim_connections,
-                [dependent_var.owner.op for dependent_var in dependent_rvs],
-                dependent_logps,
-            )
+    .. code-block:: python

-
-
-
-            rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
-            rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
-            rv_domain_tensor = pt.moveaxis(
-                pt.full(
-                    (*rv_shape, len(rv_domain)),
-                    rv_domain,
-                    dtype=marginalized_rv.dtype,
-                ),
-                -1,
-                0,
-            )
+        import pymc as pm
+        from pymc_extras import MarginalModel

-
-
-
-            )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with MarginalModel() as m:
+            p = pm.Beta("p", 1, 1)
+            x = pm.Bernoulli("x", p=p, shape=(3,))
+            y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
+
+            m.marginalize([x])
+
+            idata = pm.sample()
+            m.recover_marginals(idata, var_names=["x"])
+
+
+    """
+    unmarginal_model = unmarginalize(model)
+
+    # Find the names of the marginalized variables
+    model_var_names = set(rv.name for rv in model.free_RVs)
+    marginalized_rv_names = [
+        rv.name for rv in unmarginal_model.free_RVs if rv.name not in model_var_names
+    ]
+
+    if var_names is None:
+        var_names = marginalized_rv_names
+
+    var_names = [var if isinstance(var, str) else var.name for var in var_names]
+    var_names_to_recover = [name for name in marginalized_rv_names if name in var_names]
+    missing_names = [name for name in var_names_to_recover if name not in marginalized_rv_names]
+    if missing_names:
+        raise ValueError(f"Unrecognized var_names: {missing_names}")
+
+    if return_samples and random_seed is not None:
+        seeds = _get_seeds_per_chain(random_seed, len(var_names_to_recover))
+    else:
+        seeds = [None] * len(var_names_to_recover)
+
+    posterior_pts, stacked_dims = dataset_to_point_list(
+        # Remove Deterministics
+        idata.posterior[[rv.name for rv in model.free_RVs]],
+        sample_dims=("chain", "draw"),
+    )
+    transformed_posterior_pts = transform_posterior_pts(model, posterior_pts)
+
+    rv_dict = {}
+    rv_dims = {}
+    for seed, var_name_to_recover in zip(seeds, var_names_to_recover):
+        var_to_recover = unmarginal_model[var_name_to_recover]
+        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
+        if not isinstance(var_to_recover.owner.op, supported_dists):
+            raise NotImplementedError(
+                f"RV with distribution {var_to_recover.owner.op} cannot be recovered. "
+                f"Supported distribution include {supported_dists}"
            )

-
+        other_marginalized_rvs_names = marginalized_rv_names.copy()
+        other_marginalized_rvs_names.remove(var_name_to_recover)
+        dependent_rvs = [
+            rv
+            for rv in find_conditional_dependent_rvs(var_to_recover, unmarginal_model.basic_RVs)
+            if rv.name not in other_marginalized_rvs_names
+        ]
+        # Handle batch dims for marginalized value and its dependent RVs
+        dependent_rvs_dim_connections = subgraph_batch_dim_connection(var_to_recover, dependent_rvs)
+
+        marginalized_model = marginalize(unmarginal_model, other_marginalized_rvs_names)

-
-
-                logps = np.array(logps)
-                samples = np.array(samples)
-                rv_dict[marginalized_rv.name] = samples.reshape(
-                    tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
-                )
-            else:
-                logps = np.array(logvs)
+        marginalized_var_to_recover = marginalized_model[var_name_to_recover]
+        dependent_rvs = [marginalized_model[rv.name] for rv in dependent_rvs]

-
-
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=NonSeparableLogpWarning)
+            logps = marginalized_model.logp(
+                vars=[marginalized_var_to_recover, *dependent_rvs], sum=False
            )
-            if marginalized_rv.name in m.named_vars_to_dims:
-                rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name])
-                rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [
-                    "lp_" + marginalized_rv.name + "_dim"
-                ]
-
-        coords, dims = coords_and_dims_for_inferencedata(self)
-        dims.update(rv_dims)
-        rv_dataset = dict_to_dataset(
-            rv_dict,
-            library=pymc,
-            dims=dims,
-            coords=coords,
-            default_dims=list(sample_dims),
-            skip_event_dims=True,
-        )

-
-
-
-
-
+        marginalized_logp, *dependent_logps = logps
+        joint_logp = marginalized_logp + reduce_batch_dependent_logps(
+            dependent_rvs_dim_connections,
+            [dependent_var.owner.op for dependent_var in dependent_rvs],
+            dependent_logps,
+        )

+        marginalized_value = marginalized_model.rvs_to_values[marginalized_var_to_recover]
+        other_values = [v for v in marginalized_model.value_vars if v is not marginalized_value]
+
+        rv_shape = constant_fold(tuple(var_to_recover.shape), raise_not_constant=False)
+        rv_domain = get_domain_of_finite_discrete_rv(var_to_recover)
+        rv_domain_tensor = pt.moveaxis(
+            pt.full(
+                (*rv_shape, len(rv_domain)),
+                rv_domain,
+                dtype=var_to_recover.dtype,
+            ),
+            -1,
+            0,
+        )

-
-
+        batched_joint_logp = vectorize_graph(
+            joint_logp,
+            replace={marginalized_value: rv_domain_tensor},
+        )
+        batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1)
+
+        joint_logp_norm = log_softmax(batched_joint_logp, axis=-1)
+        if return_samples:
+            rv_draws = Categorical.dist(logit_p=batched_joint_logp)
+            if isinstance(var_to_recover.owner.op, DiscreteUniform):
+                rv_draws += rv_domain[0]
+            outputs = [joint_logp_norm, rv_draws]
+        else:
+            outputs = joint_logp_norm

-
-
+        rv_loglike_fn = compile_pymc(
+            inputs=other_values,
+            outputs=outputs,
+            on_unused_input="ignore",
+            random_seed=seed,
+        )

-
+        logvs = [rv_loglike_fn(**vs) for vs in transformed_posterior_pts]

-
-
-
-
-
-
+        if return_samples:
+            logps, samples = zip(*logvs)
+            logps = np.asarray(logps)
+            samples = np.asarray(samples)
+            rv_dict[var_name_to_recover] = samples.reshape(
+                tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
+            )
+        else:
+            logps = np.asarray(logvs)

-
-
-
-
-
-
-
-
+        rv_dict["lp_" + var_name_to_recover] = logps.reshape(
+            tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
+        )
+        if var_name_to_recover in unmarginal_model.named_vars_to_dims:
+            rv_dims[var_name_to_recover] = list(
+                unmarginal_model.named_vars_to_dims[var_name_to_recover]
+            )
+            rv_dims["lp_" + var_name_to_recover] = rv_dims[var_name_to_recover] + [
+                "lp_" + var_name_to_recover + "_dim"
+            ]
+
+    coords, dims = coords_and_dims_for_inferencedata(unmarginal_model)
+    dims.update(rv_dims)
+    rv_dataset = dict_to_dataset(
+        rv_dict,
+        library=pymc,
+        dims=dims,
+        coords=coords,
+        skip_event_dims=True,
+    )

-
-
-
+    if extend_inferencedata:
+        idata.posterior = idata.posterior.assign(rv_dataset)
+        return idata
+    else:
+        return rv_dataset


 def collect_shared_vars(outputs, blockers):
     return [
-        inp
+        inp
+        for inp in graph_inputs(outputs, blockers=blockers)
+        if (isinstance(inp, SharedVariable) and inp not in blockers)
     ]


-def
-
-    if
-
+def remove_model_vars(vars):
+    """Remove ModelVars from the graph of vars."""
+    model_vars = [var for var in vars if isinstance(var.owner.op, ModelValuedVar)]
+    replacements = [(model_var, model_var.owner.inputs[0]) for model_var in model_vars]
+    fgraph = FunctionGraph(outputs=vars, clone=False)
+    toposort_replace(fgraph, replacements)
+    return fgraph.outputs

-    marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
-    other_direct_rv_ancestors = [
-        rv
-        for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
-        if rv is not rv_to_marginalize
-    ]

+def replace_finite_discrete_marginal_subgraph(
+    fgraph, rv_to_marginalize, dependent_rvs, input_rvs
+) -> None:
     # If the marginalized RV has multiple dimensions, check that graph between
     # marginalized RV and dependent RVs does not mix information from batch dimensions
     # (otherwise logp would require enumerating over all combinations of batch dimension values)
@@ -574,22 +567,42 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
             "You can try splitting the marginalized RV into separate components and marginalizing them separately."
         ) from e

-    input_rvs = list(set((*marginalized_rv_input_rvs, *other_direct_rv_ancestors)))
     output_rvs = [rv_to_marginalize, *dependent_rvs]
+    rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)
+    outputs = output_rvs + list(rng_updates.values())
+    inputs = input_rvs + list(rng_updates.keys())
+    # Add any other shared variable inputs
+    inputs += collect_shared_vars(output_rvs, blockers=inputs)

-
-
+    inner_inputs = [inp.clone() for inp in inputs]
+    inner_outputs = clone_replace(outputs, replace=dict(zip(inputs, inner_inputs)))
+    inner_outputs = remove_model_vars(inner_outputs)

-    if isinstance(
+    if isinstance(inner_outputs[0].owner.op, DiscreteMarkovChain):
         marginalize_constructor = MarginalDiscreteMarkovChainRV
     else:
         marginalize_constructor = MarginalFiniteDiscreteRV

+    _, _, *dims = rv_to_marginalize.owner.inputs
     marginalization_op = marginalize_constructor(
-        inputs=
-        outputs=
+        inputs=inner_inputs,
+        outputs=inner_outputs,
         dims_connections=dependent_rvs_dim_connections,
+        dims=dims,
     )
-
-
-
+
+    new_outputs = marginalization_op(*inputs)
+    for old_output, new_output in zip(outputs, new_outputs):
+        new_output.name = old_output.name
+
+    model_replacements = []
+    for old_output, new_output in zip(outputs, new_outputs):
+        if old_output is rv_to_marginalize or not isinstance(old_output.owner.op, ModelValuedVar):
+            # Replace the marginalized ModelFreeRV (or non model-variables) themselves
+            var_to_replace = old_output
+        else:
+            # Replace the underlying RV, keeping the same value, transform and dims
+            var_to_replace = old_output.owner.inputs[0]
+        model_replacements.append((var_to_replace, new_output))
+
+    fgraph.replace_all(model_replacements)