pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
|
@@ -0,0 +1,595 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pymc
|
|
8
|
+
import pytensor.tensor as pt
|
|
9
|
+
|
|
10
|
+
from arviz import InferenceData, dict_to_dataset
|
|
11
|
+
from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
|
|
12
|
+
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
|
|
13
|
+
from pymc.distributions.transforms import Chain
|
|
14
|
+
from pymc.logprob.transforms import IntervalTransform
|
|
15
|
+
from pymc.model import Model
|
|
16
|
+
from pymc.pytensorf import compile_pymc, constant_fold
|
|
17
|
+
from pymc.util import RandomState, _get_seeds_per_chain, treedict
|
|
18
|
+
from pytensor.compile import SharedVariable
|
|
19
|
+
from pytensor.graph import FunctionGraph, clone_replace, graph_inputs
|
|
20
|
+
from pytensor.graph.replace import vectorize_graph
|
|
21
|
+
from pytensor.tensor import TensorVariable
|
|
22
|
+
from pytensor.tensor.special import log_softmax
|
|
23
|
+
|
|
24
|
+
__all__ = ["MarginalModel", "marginalize"]
|
|
25
|
+
|
|
26
|
+
from pymc_extras.distributions import DiscreteMarkovChain
|
|
27
|
+
from pymc_extras.model.marginal.distributions import (
|
|
28
|
+
MarginalDiscreteMarkovChainRV,
|
|
29
|
+
MarginalFiniteDiscreteRV,
|
|
30
|
+
get_domain_of_finite_discrete_rv,
|
|
31
|
+
reduce_batch_dependent_logps,
|
|
32
|
+
)
|
|
33
|
+
from pymc_extras.model.marginal.graph_analysis import (
|
|
34
|
+
find_conditional_dependent_rvs,
|
|
35
|
+
find_conditional_input_rvs,
|
|
36
|
+
is_conditional_dependent,
|
|
37
|
+
subgraph_batch_dim_connection,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class MarginalModel(Model):
    """Subclass of PyMC Model that implements functionality for automatic
    marginalization of variables in the logp transformation.

    After defining the full Model, the `marginalize` method can be used to indicate a
    subset of variables that should be marginalized.

    Notes
    -----
    Marginalization functionality is still very restricted. Only finite discrete
    variables can be marginalized. Deterministics and Potentials cannot be conditionally
    dependent on the marginalized variables.

    Furthermore, not all instances of such variables can be marginalized. If a variable
    has batched dimensions, it is required that any conditionally dependent variables
    use information from an individual batched dimension. In other words, the graph
    connecting the marginalized variable(s) to the dependent variable(s) must be
    composed strictly of Elemwise Operations. This is necessary to ensure an efficient
    logprob graph can be generated. If you want to bypass this restriction you can
    separate each dimension of the marginalized variable into the scalar components
    and then stack them together. Note that such graphs will grow exponentially in the
    number of marginalized variables.

    For the same reason, it's not possible to marginalize RVs with multivariate
    dependent RVs.

    Examples
    --------
    Marginalize over a single variable

    .. code-block:: python

        import pymc as pm
        from pymc_extras import MarginalModel

        with MarginalModel() as m:
            p = pm.Beta("p", 1, 1)
            x = pm.Bernoulli("x", p=p, shape=(3,))
            y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

            m.marginalize([x])

            idata = pm.sample()

    """
|
|
88
|
+
|
|
89
|
+
def __init__(self, *args, **kwargs):
    """Build the underlying Model and start with nothing marginalized.

    All positional and keyword arguments are forwarded unchanged to the parent
    constructor.
    """
    super().__init__(*args, **kwargs)
    # Dims of marginalized variables, keyed by variable name, remembered so they can
    # be restored when a variable is unmarginalized.
    self._marginalized_named_vars_to_dims = dict()
    # Variables that have been removed from the model mappings via `marginalize`.
    self.marginalized_rvs = list()
|
|
93
|
+
|
|
94
|
+
def _delete_rv_mappings(self, rv: TensorVariable) -> None:
    """Drop every model mapping that refers to ``rv``.

    This is how an RV is "deleted" from a model: its name, dims, value variable,
    transform and free/observed bookkeeping entries are all removed.
    """
    assert rv in self.basic_RVs, "rv is not part of the Model"

    rv_name = rv.name
    self.named_vars.pop(rv_name)
    # Dims are optional, so a missing entry is not an error.
    self.named_vars_to_dims.pop(rv_name, None)

    value_var = self.rvs_to_values.pop(rv)
    self.values_to_rvs.pop(value_var)
    self.rvs_to_transforms.pop(rv)

    if rv not in self.free_RVs:
        # A basic RV that is not free is handled as an observed RV.
        self.observed_RVs.remove(rv)
        return

    self.free_RVs.remove(rv)
    self.rvs_to_initial_values.pop(rv)
|
|
115
|
+
|
|
116
|
+
def _transfer_rv_mappings(self, old_rv: TensorVariable, new_rv: TensorVariable) -> None:
    """Transfer model mappings from old_rv to new_rv.

    ``new_rv`` takes over the name, value variable, transform and the position in
    ``free_RVs`` / ``observed_RVs`` (plus the initial value for free RVs) that
    ``old_rv`` had. ``old_rv`` is no longer referenced by the model afterwards.

    Parameters
    ----------
    old_rv : TensorVariable
        Variable currently registered in the model (must be in ``basic_RVs``).
    new_rv : TensorVariable
        Replacement variable (must not already be in ``basic_RVs``).
    """

    assert old_rv in self.basic_RVs, "old_rv is not part of the Model"
    assert new_rv not in self.basic_RVs, "new_rv is already part of the Model"

    self.named_vars.pop(old_rv.name)
    new_rv.name = old_rv.name
    self.named_vars[new_rv.name] = new_rv
    # `named_vars_to_dims` is keyed by variable *name* and the name is preserved
    # above, so the dims entry (if any) already applies to `new_rv`; there is
    # nothing to move. (Looking the variable object itself up in that mapping, as
    # was done before, can never match a name key.)

    value = self.rvs_to_values.pop(old_rv)
    self.rvs_to_values[new_rv] = value
    self.values_to_rvs[value] = new_rv

    self.rvs_to_transforms[new_rv] = self.rvs_to_transforms.pop(old_rv)
    if old_rv in self.free_RVs:
        # Replace in place so the ordering of free RVs is preserved.
        index = self.free_RVs.index(old_rv)
        self.free_RVs[index] = new_rv
        self.rvs_to_initial_values[new_rv] = self.rvs_to_initial_values.pop(old_rv)
    elif old_rv in self.observed_RVs:
        index = self.observed_RVs.index(old_rv)
        self.observed_RVs[index] = new_rv
|
|
142
|
+
|
|
143
|
+
def _marginalize(self, user_warnings=False):
    """Replace every marginalized RV and its dependent RVs by marginal RVs, in place.

    The marginalized RVs are processed in reverse topological order. For each one,
    the subgraph made of that RV and its conditionally dependent RVs is replaced
    (see `replace_finite_discrete_marginal_subgraph`) and the model mappings of the
    replaced basic RVs are transferred to the new variables.

    Parameters
    ----------
    user_warnings : bool, default False
        If True, emit ``UserWarning`` when several dependent RVs share one marginal
        RV, and when a dependent RV carries an interval transform.

    Returns
    -------
    self

    Raises
    ------
    NotImplementedError
        If a Deterministic or Potential conditionally depends on an RV to marginalize.
    """
    # clone=False: the graph is built on the model's own variables, not on copies.
    fg = FunctionGraph(outputs=self.basic_RVs + self.marginalized_rvs, clone=False)

    toposort = fg.toposort()
    # NOTE(review): this binds the *same* list object as `self.marginalized_rvs`, so
    # the `.remove()` below also shrinks `self.marginalized_rvs`. Within this file
    # `_marginalize` is only invoked on clones (`self.clone()._marginalize(...)`).
    rvs_left_to_marginalize = self.marginalized_rvs
    # `sorted` materializes a new list, so mutating the aliased list inside the loop
    # does not disturb the iteration.
    for rv_to_marginalize in sorted(
        self.marginalized_rvs,
        key=lambda rv: toposort.index(rv.owner),
        reverse=True,
    ):
        # Check that no deterministics or potentials depend on the rv to marginalize
        for det in self.deterministics:
            if is_conditional_dependent(
                det, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
            ):
                raise NotImplementedError(
                    f"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}"
                )
        for pot in self.potentials:
            if is_conditional_dependent(
                pot, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
            ):
                raise NotImplementedError(
                    f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
                )

        # old_rvs[0] is the marginalized RV itself; the remaining entries are its
        # dependent RVs. new_rvs are the matching outputs of the marginal op.
        old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph(
            fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
        )

        # More than two outputs means more than one dependent RV.
        if user_warnings and len(new_rvs) > 2:
            warnings.warn(
                "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
                f"Their joint logp terms will be assigned to the first RV: {old_rvs[1]}",
                UserWarning,
            )

        rvs_left_to_marginalize.remove(rv_to_marginalize)

        for old_rv, new_rv in zip(old_rvs, new_rvs):
            new_rv.name = old_rv.name
            if old_rv in self.marginalized_rvs:
                # Keep the position of a still-marginalized RV that got replaced.
                idx = self.marginalized_rvs.index(old_rv)
                self.marginalized_rvs.pop(idx)
                self.marginalized_rvs.insert(idx, new_rv)
            if old_rv in self.basic_RVs:
                self._transfer_rv_mappings(old_rv, new_rv)
                if user_warnings:
                    # Interval transforms for dependent variable won't work for non-constant bounds because
                    # the RV inputs are now different and may depend on another RV that also depends on the
                    # same marginalized RV
                    transform = self.rvs_to_transforms[new_rv]
                    if isinstance(transform, IntervalTransform) or (
                        isinstance(transform, Chain)
                        and any(
                            isinstance(tr, IntervalTransform) for tr in transform.transform_list
                        )
                    ):
                        warnings.warn(
                            f"The transform {transform} for the variable {old_rv}, which depends on the "
                            f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.",
                            UserWarning,
                        )
    return self
|
|
207
|
+
|
|
208
|
+
def _logp(self, *args, **kwargs):
    """Call the parent class ``logp`` directly, bypassing the override in this class."""
    parent_logp = super().logp(*args, **kwargs)
    return parent_logp
|
|
210
|
+
|
|
211
|
+
def logp(self, vars=None, **kwargs):
    """Return the logp of the model with the marginalized variables summed out.

    The computation happens on a marginalized clone, so this model is left
    untouched. When ``vars`` is given (a single variable or a sequence of them),
    each entry is mapped by name to the matching variable of the clone.
    Remaining keyword arguments are forwarded to the parent ``logp``.
    """
    marginal_clone = self.clone()._marginalize()
    if vars is None:
        return marginal_clone._logp(vars=None, **kwargs)

    requested = vars if isinstance(vars, Sequence) else (vars,)
    cloned_vars = [marginal_clone[requested_var.name] for requested_var in requested]
    return marginal_clone._logp(vars=cloned_vars, **kwargs)
|
|
218
|
+
|
|
219
|
+
@staticmethod
def from_model(model: Union[Model, "MarginalModel"]) -> "MarginalModel":
    """Create a new ``MarginalModel`` holding clones of all variables of ``model``.

    Parameters
    ----------
    model : Model or MarginalModel
        Model to copy. If it is a ``MarginalModel``, its marginalized variables and
        their remembered dims are carried over as well.

    Returns
    -------
    MarginalModel
        New model whose variables are clones of the originals. The mapping from
        original to cloned variable is stored on the result as ``vars_to_clone``.
        Value variables, transforms and initial values are reused as-is.
    """
    import copy

    new_model = MarginalModel(coords=model.coords)
    if isinstance(model, MarginalModel):
        marginalized_rvs = model.marginalized_rvs
        marginalized_named_vars_to_dims = model._marginalized_named_vars_to_dims
    else:
        marginalized_rvs = []
        marginalized_named_vars_to_dims = {}

    model_vars = model.basic_RVs + model.potentials + model.deterministics + marginalized_rvs
    # Named variables that are not RVs, potentials or deterministics.
    data_vars = [var for var in model.named_vars.values() if var not in model_vars]
    vars = model_vars + data_vars
    cloned_vars = clone_replace(vars)
    vars_to_clone = dict(zip(vars, cloned_vars))
    new_model.vars_to_clone = vars_to_clone

    new_model.named_vars = treedict(
        {name: vars_to_clone[var] for name, var in model.named_vars.items()}
    )
    # Shallow-copy the dims mappings: `marginalize` / `unmarginalize` pop and add
    # entries on the model they are called on, and sharing the container would make
    # those edits leak into the model that was cloned.
    new_model.named_vars_to_dims = copy.copy(model.named_vars_to_dims)
    new_model.values_to_rvs = {vv: vars_to_clone[rv] for vv, rv in model.values_to_rvs.items()}
    new_model.rvs_to_values = {vars_to_clone[rv]: vv for rv, vv in model.rvs_to_values.items()}
    new_model.rvs_to_transforms = {
        vars_to_clone[rv]: tr for rv, tr in model.rvs_to_transforms.items()
    }
    new_model.rvs_to_initial_values = {
        vars_to_clone[rv]: iv for rv, iv in model.rvs_to_initial_values.items()
    }
    new_model.free_RVs = [vars_to_clone[rv] for rv in model.free_RVs]
    new_model.observed_RVs = [vars_to_clone[rv] for rv in model.observed_RVs]
    new_model.potentials = [vars_to_clone[pot] for pot in model.potentials]
    new_model.deterministics = [vars_to_clone[det] for det in model.deterministics]

    new_model.marginalized_rvs = [vars_to_clone[rv] for rv in marginalized_rvs]
    new_model._marginalized_named_vars_to_dims = copy.copy(marginalized_named_vars_to_dims)
    return new_model
|
|
256
|
+
|
|
257
|
+
def clone(self):
    """Return a copy of this model built with `from_model`."""
    duplicate = self.from_model(self)
    return duplicate
|
|
259
|
+
|
|
260
|
+
def marginalize(
    self,
    rvs_to_marginalize: ModelRVs,
):
    """Mark one or more free RVs of this model as marginalized.

    Each variable is removed from the model mappings and appended to
    ``marginalized_rvs``; its dims (if any) are remembered so they can be restored
    by `unmarginalize`. A trial marginalization is then run on a clone so that
    errors and warnings surface immediately.

    Parameters
    ----------
    rvs_to_marginalize : TensorVariable, str, or sequence of either
        Variable(s), or name(s) of variable(s), to marginalize. A single string is
        treated as one variable name.

    Raises
    ------
    ValueError
        If a variable is not a free RV of the model.
    NotImplementedError
        If a variable's distribution is not supported: a DiscreteMarkovChain with
        ``n_lags > 1`` or a non-matrix transition probability, or anything other
        than Bernoulli, Categorical, DiscreteUniform or DiscreteMarkovChain.
    """
    # `str` is itself a Sequence, so it has to be singled out explicitly; otherwise a
    # name such as "idx" would be iterated character by character.
    if isinstance(rvs_to_marginalize, str) or not isinstance(rvs_to_marginalize, Sequence):
        rvs_to_marginalize = (rvs_to_marginalize,)
    rvs_to_marginalize = [
        self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
    ]

    for rv_to_marginalize in rvs_to_marginalize:
        if rv_to_marginalize not in self.free_RVs:
            raise ValueError(
                f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
            )

        rv_op = rv_to_marginalize.owner.op
        if isinstance(rv_op, DiscreteMarkovChain):
            if rv_op.n_lags > 1:
                raise NotImplementedError(
                    "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
                )
            if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
                raise NotImplementedError(
                    "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
                )
        elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
            raise NotImplementedError(
                f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
            )

        # Remember the dims before the mappings (including dims) are deleted.
        if rv_to_marginalize.name in self.named_vars_to_dims:
            dims = self.named_vars_to_dims[rv_to_marginalize.name]
            self._marginalized_named_vars_to_dims[rv_to_marginalize.name] = dims

        self._delete_rv_mappings(rv_to_marginalize)
        self.marginalized_rvs.append(rv_to_marginalize)

    # Raise errors and warnings immediately
    self.clone()._marginalize(user_warnings=True)
|
|
300
|
+
|
|
301
|
+
def _to_transformed(self):
    """Create a function from the untransformed space to the transformed space.

    Returns
    -------
    fn
        Compiled function taking the free RVs as inputs and returning, for each free
        RV, either the RV itself (no transform) or its forward-transformed value.
    transformed_names : list of str
        Output names, in the same order: the RV name when there is no transform,
        otherwise the name of the RV's value variable.
    """
    outputs, names = [], []

    for free_rv in self.free_RVs:
        rv_transform = self.rvs_to_transforms.get(free_rv)
        if rv_transform is not None:
            outputs.append(rv_transform.forward(free_rv, *free_rv.owner.inputs))
            names.append(self.rvs_to_values[free_rv].name)
            continue
        outputs.append(free_rv)
        names.append(free_rv.name)

    fn = self.compile_fn(inputs=self.free_RVs, outs=outputs)
    return fn, names
|
|
318
|
+
|
|
319
|
+
def unmarginalize(self, rvs_to_unmarginalize: Sequence[TensorVariable | str]):
    """Put previously marginalized variables back into the model.

    Each variable is removed from ``marginalized_rvs`` and registered again under
    its name, with the dims that were remembered when it was marginalized (if any).

    Parameters
    ----------
    rvs_to_unmarginalize : sequence of TensorVariable or str
        Marginalized variables, or their names.

    Raises
    ------
    KeyError
        If a name matches neither a marginalized variable nor a model variable.
    ValueError
        If a variable is not currently marginalized.
    """
    for rv in rvs_to_unmarginalize:
        if isinstance(rv, str):
            # Marginalized variables were dropped from `named_vars` when they were
            # marginalized, so they must be looked up among `marginalized_rvs`.
            by_name = next(
                (candidate for candidate in self.marginalized_rvs if candidate.name == rv),
                None,
            )
            # Fall back to the model lookup so unknown names fail as they did before.
            rv = by_name if by_name is not None else self[rv]
        self.marginalized_rvs.remove(rv)
        dims = self._marginalized_named_vars_to_dims.pop(rv.name, None)
        self.register_rv(rv, name=rv.name, dims=dims)
|
|
329
|
+
|
|
330
|
+
def recover_marginals(
    self,
    idata: InferenceData,
    var_names: Sequence[str] | None = None,
    return_samples: bool = True,
    extend_inferencedata: bool = True,
    random_seed: RandomState = None,
):
    """Computes posterior log-probabilities and samples of marginalized variables
    conditioned on parameters of the model given InferenceData with posterior group

    When there are multiple marginalized variables, each marginalized variable is
    conditioned on both the parameters and the other variables still marginalized

    All log-probabilities are within the transformed space

    Parameters
    ----------
    idata : InferenceData
        InferenceData with posterior group
    var_names : sequence of str, optional
        List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables
    return_samples : bool, default True
        If True, also return samples of the marginalized variables
    extend_inferencedata : bool, default True
        Whether to extend the original InferenceData or return a new one
    random_seed: int, array-like of int or SeedSequence, optional
        Seed used for generating samples

    Returns
    -------
    idata : InferenceData
        If ``extend_inferencedata`` is True, the input InferenceData with a
        ``lp_{varname}`` (and, if ``return_samples``, a ``{varname}``) entry added to
        the posterior group for each marginalized variable in ``var_names``.
        Otherwise only the dataset holding those new entries is returned.

    Raises
    ------
    ValueError
        If ``var_names`` contains a name that is not a marginalized variable.
    NotImplementedError
        If a requested variable is not Bernoulli, Categorical or DiscreteUniform.

    .. code-block:: python

        import pymc as pm
        from pymc_extras import MarginalModel

        with MarginalModel() as m:
            p = pm.Beta("p", 1, 1)
            x = pm.Bernoulli("x", p=p, shape=(3,))
            y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

            m.marginalize([x])

            idata = pm.sample()
            m.recover_marginals(idata, var_names=["x"])


    """
    if var_names is None:
        var_names = [var.name for var in self.marginalized_rvs]

    var_names = [var if isinstance(var, str) else var.name for var in var_names]
    # Validate against the names that are actually marginalized. Checking the
    # already-filtered `vars_to_recover` list could never report a missing name.
    marginalized_names = {rv.name for rv in self.marginalized_rvs}
    missing_names = [name for name in var_names if name not in marginalized_names]
    if missing_names:
        raise ValueError(f"Unrecognized var_names: {missing_names}")
    vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names]

    if return_samples and random_seed is not None:
        seeds = _get_seeds_per_chain(random_seed, len(vars_to_recover))
    else:
        seeds = [None] * len(vars_to_recover)

    posterior = idata.posterior

    # Remove Deterministics
    posterior_values = posterior[
        [rv.name for rv in self.free_RVs if rv not in self.marginalized_rvs]
    ]

    sample_dims = ("chain", "draw")
    posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
    # Leading (chain, draw) shape used to un-stack the per-point results below.
    sample_shape = tuple(len(coord) for coord in stacked_dims.values())

    # Handle Transforms
    transform_fn, transform_names = self._to_transformed()

    def transform_input(inputs):
        return dict(zip(transform_names, transform_fn(inputs)))

    posterior_pts = [transform_input(vs) for vs in posterior_pts]

    supported_dists = (Bernoulli, Categorical, DiscreteUniform)
    rv_dict = {}
    rv_dims = {}
    for seed, marginalized_rv in zip(seeds, vars_to_recover):
        if not isinstance(marginalized_rv.owner.op, supported_dists):
            raise NotImplementedError(
                f"RV with distribution {marginalized_rv.owner.op} cannot be recovered. "
                f"Supported distribution include {supported_dists}"
            )

        # Work on a clone in which only this variable is unmarginalized; the others
        # stay marginalized.
        m = self.clone()
        marginalized_rv = m.vars_to_clone[marginalized_rv]
        m.unmarginalize([marginalized_rv])
        dependent_rvs = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
        logps = m.logp(vars=[marginalized_rv, *dependent_rvs], sum=False)

        # Handle batch dims for marginalized value and its dependent RVs
        dependent_rvs_dim_connections = subgraph_batch_dim_connection(
            marginalized_rv, dependent_rvs
        )
        marginalized_logp, *dependent_logps = logps
        joint_logp = marginalized_logp + reduce_batch_dependent_logps(
            dependent_rvs_dim_connections,
            [dependent_var.owner.op for dependent_var in dependent_rvs],
            dependent_logps,
        )

        marginalized_value = m.rvs_to_values[marginalized_rv]
        other_values = [v for v in m.value_vars if v is not marginalized_value]

        # Tensor enumerating the whole domain along a new leading axis.
        rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
        rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
        rv_domain_tensor = pt.moveaxis(
            pt.full(
                (*rv_shape, len(rv_domain)),
                rv_domain,
                dtype=marginalized_rv.dtype,
            ),
            -1,
            0,
        )

        # Evaluate the joint logp at every domain value, then move the domain axis
        # to the end so the normalization runs over it.
        batched_joint_logp = vectorize_graph(
            joint_logp,
            replace={marginalized_value: rv_domain_tensor},
        )
        batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1)

        joint_logp_norm = log_softmax(batched_joint_logp, axis=-1)
        if return_samples:
            rv_draws = pymc.Categorical.dist(logit_p=batched_joint_logp)
            if isinstance(marginalized_rv.owner.op, DiscreteUniform):
                # Categorical draws are indices into the domain; shift by its first value.
                rv_draws += rv_domain[0]
            outputs = [joint_logp_norm, rv_draws]
        else:
            outputs = joint_logp_norm

        rv_loglike_fn = compile_pymc(
            inputs=other_values,
            outputs=outputs,
            on_unused_input="ignore",
            random_seed=seed,
        )

        logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]

        if return_samples:
            logps, samples = zip(*logvs)
            logps = np.array(logps)
            samples = np.array(samples)
            rv_dict[marginalized_rv.name] = samples.reshape(
                sample_shape + samples.shape[1:],
            )
        else:
            logps = np.array(logvs)

        rv_dict["lp_" + marginalized_rv.name] = logps.reshape(
            sample_shape + logps.shape[1:],
        )
        if marginalized_rv.name in m.named_vars_to_dims:
            rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name])
            # The log-probabilities carry one extra trailing axis over the domain.
            rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [
                "lp_" + marginalized_rv.name + "_dim"
            ]

    coords, dims = coords_and_dims_for_inferencedata(self)
    dims.update(rv_dims)
    rv_dataset = dict_to_dataset(
        rv_dict,
        library=pymc,
        dims=dims,
        coords=coords,
        default_dims=list(sample_dims),
        skip_event_dims=True,
    )

    if extend_inferencedata:
        idata.posterior = idata.posterior.assign(rv_dataset)
        return idata
    else:
        return rv_dataset
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
    """Marginalize a subset of variables in a PyMC model.

    A `MarginalModel` is built from the given `Model` and the requested variables
    are marginalized in it.

    See documentation for `MarginalModel` for more information.

    Parameters
    ----------
    model : Model
        PyMC model to marginalize. Original variables will be cloned.
    rvs_to_marginalize : Sequence[TensorVariable]
        Variables to marginalize in the returned model. Anything that is not a
        tuple or list is treated as a single variable (or variable name).

    Returns
    -------
    marginal_model: MarginalModel
        Marginal model with the specified variables marginalized.
    """
    if isinstance(rvs_to_marginalize, tuple | list):
        requested = rvs_to_marginalize
    else:
        requested = (rvs_to_marginalize,)

    # Refer to variables by name: the new model holds clones, not the originals.
    names = []
    for entry in requested:
        names.append(entry if isinstance(entry, str) else entry.name)

    result = MarginalModel.from_model(model)
    result.marginalize(names)
    return result
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
def collect_shared_vars(outputs, blockers):
    """Return the shared variables among the graph inputs of ``outputs``.

    The graph inputs are obtained with ``graph_inputs(outputs, blockers=blockers)``
    and only those that are ``SharedVariable`` instances are kept, in that order.
    """
    shared_inputs = []
    for candidate in graph_inputs(outputs, blockers=blockers):
        if isinstance(candidate, SharedVariable):
            shared_inputs.append(candidate)
    return shared_inputs
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
    """Replace a marginalized RV and its dependent RVs in ``fgraph`` by a marginal op.

    Parameters
    ----------
    fgraph : FunctionGraph
        Graph in which the replacement is performed (modified in place).
    rv_to_marginalize : TensorVariable
        The RV to marginalize.
    all_rvs : list of TensorVariable
        All RVs considered when searching for dependent and input RVs.

    Returns
    -------
    output_rvs : list of TensorVariable
        The replaced variables: ``rv_to_marginalize`` followed by its dependent RVs.
    new_output_rvs
        The outputs of the marginal op that replace them, in the same order.

    Raises
    ------
    ValueError
        If no RV depends on ``rv_to_marginalize``.
    NotImplementedError
        If the batch dimension connections between the marginalized RV and its
        dependent RVs cannot be established.
    """
    dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
    if not dependent_rvs:
        raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

    marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
    other_direct_rv_ancestors = [
        rv
        for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
        if rv is not rv_to_marginalize
    ]

    # If the marginalized RV has multiple dimensions, check that graph between
    # marginalized RV and dependent RVs does not mix information from batch dimensions
    # (otherwise logp would require enumerating over all combinations of batch dimension values)
    try:
        dependent_rvs_dim_connections = subgraph_batch_dim_connection(
            rv_to_marginalize, dependent_rvs
        )
    except (ValueError, NotImplementedError) as e:
        # For the perspective of the user this is a NotImplementedError
        raise NotImplementedError(
            "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. "
            "You can try splitting the marginalized RV into separate components and marginalizing them separately."
        ) from e

    # De-duplicate while keeping first-seen order, so the inputs of the marginal op
    # come out in the same order on every run (a set would order them arbitrarily).
    input_rvs = list(dict.fromkeys((*marginalized_rv_input_rvs, *other_direct_rv_ancestors)))
    output_rvs = [rv_to_marginalize, *dependent_rvs]

    # We are strict about shared variables in SymbolicRandomVariables
    inputs = input_rvs + collect_shared_vars(output_rvs, blockers=input_rvs)

    if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
        marginalize_constructor = MarginalDiscreteMarkovChainRV
    else:
        marginalize_constructor = MarginalFiniteDiscreteRV

    marginalization_op = marginalize_constructor(
        inputs=inputs,
        outputs=output_rvs,  # TODO: Add RNG updates to outputs so this can be used in the generative graph
        dims_connections=dependent_rvs_dim_connections,
    )
    new_output_rvs = marginalization_op(*inputs)
    fgraph.replace_all(tuple(zip(output_rvs, new_output_rvs)))
    return output_rvs, new_output_rvs
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from functools import wraps
|
|
2
|
+
|
|
3
|
+
from pymc import Model
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def as_model(*model_args, **model_kwargs):
    R"""
    Decorator to provide context to PyMC models declared in a function.
    This removes all need to think about context managers and lets you separate creating a generative model from using the model.
    Additionally, a coords argument is added to the function so coords can be changed during function invocation

    Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.

    Coords given to the decorator and coords given at call time are merged; on a
    shared key the call-time value wins. The decorated function can be called any
    number of times and each call builds a fresh model.

    Examples
    --------
    .. code:: python

        import pymc as pm
        import pymc_extras as pmx

        # The following are equivalent

        # standard PyMC API with context manager
        with pm.Model(coords={"obs": ["a", "b"]}) as model:
            x = pm.Normal("x", 0., 1., dims="obs")
            pm.sample()

        # functional API using decorator
        @pmx.as_model(coords={"obs": ["a", "b"]})
        def basic_model():
            pm.Normal("x", 0., 1., dims="obs")

        m = basic_model()
        pm.sample(model=m)

        # alternative way to use functional API
        @pmx.as_model()
        def basic_model():
            pm.Normal("x", 0., 1., dims="obs")

        m = basic_model(coords={"obs": ["a", "b"]})
        pm.sample(model=m)

    """

    def decorator(f):
        @wraps(f)
        def make_model(*args, **kwargs):
            # Read the decorator-level coords without mutating `model_kwargs`: it is
            # shared by every call, and popping from it would silently drop those
            # coords from the second call onwards.
            decorator_coords = model_kwargs.get("coords") or {}
            # The call-level `coords` is consumed here and not forwarded to `f`.
            call_coords = kwargs.pop("coords", None) or {}
            coords = decorator_coords | call_coords
            other_model_kwargs = {
                key: value for key, value in model_kwargs.items() if key != "coords"
            }
            with Model(*model_args, coords=coords, **other_model_kwargs) as m:
                f(*args, **kwargs)
            return m

        return make_model

    return decorator
|
|
File without changes
|