pymc_extras-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
pymc_extras/model/marginal/marginal_model.py
@@ -0,0 +1,595 @@
+ import warnings
+
+ from collections.abc import Sequence
+ from typing import Union
+
+ import numpy as np
+ import pymc
+ import pytensor.tensor as pt
+
+ from arviz import InferenceData, dict_to_dataset
+ from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
+ from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
+ from pymc.distributions.transforms import Chain
+ from pymc.logprob.transforms import IntervalTransform
+ from pymc.model import Model
+ from pymc.pytensorf import compile_pymc, constant_fold
+ from pymc.util import RandomState, _get_seeds_per_chain, treedict
+ from pytensor.compile import SharedVariable
+ from pytensor.graph import FunctionGraph, clone_replace, graph_inputs
+ from pytensor.graph.replace import vectorize_graph
+ from pytensor.tensor import TensorVariable
+ from pytensor.tensor.special import log_softmax
+
+ __all__ = ["MarginalModel", "marginalize"]
+
+ from pymc_extras.distributions import DiscreteMarkovChain
+ from pymc_extras.model.marginal.distributions import (
+     MarginalDiscreteMarkovChainRV,
+     MarginalFiniteDiscreteRV,
+     get_domain_of_finite_discrete_rv,
+     reduce_batch_dependent_logps,
+ )
+ from pymc_extras.model.marginal.graph_analysis import (
+     find_conditional_dependent_rvs,
+     find_conditional_input_rvs,
+     is_conditional_dependent,
+     subgraph_batch_dim_connection,
+ )
+
+ ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]
+
+
+ class MarginalModel(Model):
+     """Subclass of PyMC Model that implements functionality for automatic
+     marginalization of variables in the logp transformation
+
+     After defining the full Model, the `marginalize` method can be used to indicate a
+     subset of variables that should be marginalized
+
+     Notes
+     -----
+     Marginalization functionality is still very restricted. Only finite discrete
+     variables can be marginalized. Deterministics and Potentials cannot be conditionally
+     dependent on the marginalized variables.
+
+     Furthermore, not all instances of such variables can be marginalized. If a variable
+     has batched dimensions, it is required that any conditionally dependent variables
+     use information from an individual batched dimension. In other words, the graph
+     connecting the marginalized variable(s) to the dependent variable(s) must be
+     composed strictly of Elemwise Operations. This is necessary to ensure an efficient
+     logprob graph can be generated. If you want to bypass this restriction you can
+     separate each dimension of the marginalized variable into the scalar components
+     and then stack them together. Note that such graphs will grow exponentially in the
+     number of marginalized variables.
+
+     For the same reason, it's not possible to marginalize RVs with multivariate
+     dependent RVs.
+
+     Examples
+     --------
+     Marginalize over a single variable
+
+     .. code-block:: python
+
+         import pymc as pm
+         from pymc_extras import MarginalModel
+
+         with MarginalModel() as m:
+             p = pm.Beta("p", 1, 1)
+             x = pm.Bernoulli("x", p=p, shape=(3,))
+             y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
+
+             m.marginalize([x])
+
+             idata = pm.sample()
+
+     """
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.marginalized_rvs = []
+         self._marginalized_named_vars_to_dims = {}
+
+     def _delete_rv_mappings(self, rv: TensorVariable) -> None:
+         """Remove all model mappings referring to rv
+
+         This can be used to "delete" an RV from a model
+         """
+         assert rv in self.basic_RVs, "rv is not part of the Model"
+
+         name = rv.name
+         self.named_vars.pop(name)
+         if name in self.named_vars_to_dims:
+             self.named_vars_to_dims.pop(name)
+
+         value = self.rvs_to_values.pop(rv)
+         self.values_to_rvs.pop(value)
+
+         self.rvs_to_transforms.pop(rv)
+         if rv in self.free_RVs:
+             self.free_RVs.remove(rv)
+             self.rvs_to_initial_values.pop(rv)
+         else:
+             self.observed_RVs.remove(rv)
+
+     def _transfer_rv_mappings(self, old_rv: TensorVariable, new_rv: TensorVariable) -> None:
+         """Transfer model mappings from old_rv to new_rv"""
+
+         assert old_rv in self.basic_RVs, "old_rv is not part of the Model"
+         assert new_rv not in self.basic_RVs, "new_rv is already part of the Model"
+
+         self.named_vars.pop(old_rv.name)
+         new_rv.name = old_rv.name
+         self.named_vars[new_rv.name] = new_rv
+         if old_rv in self.named_vars_to_dims:
+             self._RV_dims[new_rv] = self._RV_dims.pop(old_rv)
+
+         value = self.rvs_to_values.pop(old_rv)
+         self.rvs_to_values[new_rv] = value
+         self.values_to_rvs[value] = new_rv
+
+         self.rvs_to_transforms[new_rv] = self.rvs_to_transforms.pop(old_rv)
+         if old_rv in self.free_RVs:
+             index = self.free_RVs.index(old_rv)
+             self.free_RVs.pop(index)
+             self.free_RVs.insert(index, new_rv)
+             self.rvs_to_initial_values[new_rv] = self.rvs_to_initial_values.pop(old_rv)
+         elif old_rv in self.observed_RVs:
+             index = self.observed_RVs.index(old_rv)
+             self.observed_RVs.pop(index)
+             self.observed_RVs.insert(index, new_rv)
+
+     def _marginalize(self, user_warnings=False):
+         fg = FunctionGraph(outputs=self.basic_RVs + self.marginalized_rvs, clone=False)
+
+         toposort = fg.toposort()
+         rvs_left_to_marginalize = self.marginalized_rvs
+         for rv_to_marginalize in sorted(
+             self.marginalized_rvs,
+             key=lambda rv: toposort.index(rv.owner),
+             reverse=True,
+         ):
+             # Check that no deterministics or potentials depend on the rv to marginalize
+             for det in self.deterministics:
+                 if is_conditional_dependent(
+                     det, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
+                 ):
+                     raise NotImplementedError(
+                         f"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}"
+                     )
+             for pot in self.potentials:
+                 if is_conditional_dependent(
+                     pot, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
+                 ):
+                     raise NotImplementedError(
+                         f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
+                     )
+
+             old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph(
+                 fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
+             )
+
+             if user_warnings and len(new_rvs) > 2:
+                 warnings.warn(
+                     "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
+                     f"Their joint logp terms will be assigned to the first RV: {old_rvs[1]}",
+                     UserWarning,
+                 )
+
+             rvs_left_to_marginalize.remove(rv_to_marginalize)
+
+             for old_rv, new_rv in zip(old_rvs, new_rvs):
+                 new_rv.name = old_rv.name
+                 if old_rv in self.marginalized_rvs:
+                     idx = self.marginalized_rvs.index(old_rv)
+                     self.marginalized_rvs.pop(idx)
+                     self.marginalized_rvs.insert(idx, new_rv)
+                 if old_rv in self.basic_RVs:
+                     self._transfer_rv_mappings(old_rv, new_rv)
+                     if user_warnings:
+                         # Interval transforms for dependent variable won't work for non-constant bounds because
+                         # the RV inputs are now different and may depend on another RV that also depends on the
+                         # same marginalized RV
+                         transform = self.rvs_to_transforms[new_rv]
+                         if isinstance(transform, IntervalTransform) or (
+                             isinstance(transform, Chain)
+                             and any(
+                                 isinstance(tr, IntervalTransform) for tr in transform.transform_list
+                             )
+                         ):
+                             warnings.warn(
+                                 f"The transform {transform} for the variable {old_rv}, which depends on the "
+                                 f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.",
+                                 UserWarning,
+                             )
+         return self
+
+     def _logp(self, *args, **kwargs):
+         return super().logp(*args, **kwargs)
+
+     def logp(self, vars=None, **kwargs):
+         m = self.clone()._marginalize()
+         if vars is not None:
+             if not isinstance(vars, Sequence):
+                 vars = (vars,)
+             vars = [m[var.name] for var in vars]
+         return m._logp(vars=vars, **kwargs)
+
+     @staticmethod
+     def from_model(model: Union[Model, "MarginalModel"]) -> "MarginalModel":
+         new_model = MarginalModel(coords=model.coords)
+         if isinstance(model, MarginalModel):
+             marginalized_rvs = model.marginalized_rvs
+             marginalized_named_vars_to_dims = model._marginalized_named_vars_to_dims
+         else:
+             marginalized_rvs = []
+             marginalized_named_vars_to_dims = {}
+
+         model_vars = model.basic_RVs + model.potentials + model.deterministics + marginalized_rvs
+         data_vars = [var for name, var in model.named_vars.items() if var not in model_vars]
+         vars = model_vars + data_vars
+         cloned_vars = clone_replace(vars)
+         vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
+         new_model.vars_to_clone = vars_to_clone
+
+         new_model.named_vars = treedict(
+             {name: vars_to_clone[var] for name, var in model.named_vars.items()}
+         )
+         new_model.named_vars_to_dims = model.named_vars_to_dims
+         new_model.values_to_rvs = {vv: vars_to_clone[rv] for vv, rv in model.values_to_rvs.items()}
+         new_model.rvs_to_values = {vars_to_clone[rv]: vv for rv, vv in model.rvs_to_values.items()}
+         new_model.rvs_to_transforms = {
+             vars_to_clone[rv]: tr for rv, tr in model.rvs_to_transforms.items()
+         }
+         new_model.rvs_to_initial_values = {
+             vars_to_clone[rv]: iv for rv, iv in model.rvs_to_initial_values.items()
+         }
+         new_model.free_RVs = [vars_to_clone[rv] for rv in model.free_RVs]
+         new_model.observed_RVs = [vars_to_clone[rv] for rv in model.observed_RVs]
+         new_model.potentials = [vars_to_clone[pot] for pot in model.potentials]
+         new_model.deterministics = [vars_to_clone[det] for det in model.deterministics]
+
+         new_model.marginalized_rvs = [vars_to_clone[rv] for rv in marginalized_rvs]
+         new_model._marginalized_named_vars_to_dims = marginalized_named_vars_to_dims
+         return new_model
+
+     def clone(self):
+         return self.from_model(self)
+
+     def marginalize(
+         self,
+         rvs_to_marginalize: ModelRVs,
+     ):
+         if not isinstance(rvs_to_marginalize, Sequence):
+             rvs_to_marginalize = (rvs_to_marginalize,)
+         rvs_to_marginalize = [
+             self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
+         ]
+
+         for rv_to_marginalize in rvs_to_marginalize:
+             if rv_to_marginalize not in self.free_RVs:
+                 raise ValueError(
+                     f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
+                 )
+
+             rv_op = rv_to_marginalize.owner.op
+             if isinstance(rv_op, DiscreteMarkovChain):
+                 if rv_op.n_lags > 1:
+                     raise NotImplementedError(
+                         "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
+                     )
+                 if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
+                     raise NotImplementedError(
+                         "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
+                     )
+             elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
+                 raise NotImplementedError(
+                     f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
+                 )
+
+             if rv_to_marginalize.name in self.named_vars_to_dims:
+                 dims = self.named_vars_to_dims[rv_to_marginalize.name]
+                 self._marginalized_named_vars_to_dims[rv_to_marginalize.name] = dims
+
+             self._delete_rv_mappings(rv_to_marginalize)
+             self.marginalized_rvs.append(rv_to_marginalize)
+
+         # Raise errors and warnings immediately
+         self.clone()._marginalize(user_warnings=True)
+
+     def _to_transformed(self):
+         """Create a function from the untransformed space to the transformed space"""
+         transformed_rvs = []
+         transformed_names = []
+
+         for rv in self.free_RVs:
+             transform = self.rvs_to_transforms.get(rv)
+             if transform is None:
+                 transformed_rvs.append(rv)
+                 transformed_names.append(rv.name)
+             else:
+                 transformed_rv = transform.forward(rv, *rv.owner.inputs)
+                 transformed_rvs.append(transformed_rv)
+                 transformed_names.append(self.rvs_to_values[rv].name)
+
+         fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
+         return fn, transformed_names
+
+     def unmarginalize(self, rvs_to_unmarginalize: Sequence[TensorVariable | str]):
+         for rv in rvs_to_unmarginalize:
+             if isinstance(rv, str):
+                 rv = self[rv]
+             self.marginalized_rvs.remove(rv)
+             if rv.name in self._marginalized_named_vars_to_dims:
+                 dims = self._marginalized_named_vars_to_dims.pop(rv.name)
+             else:
+                 dims = None
+             self.register_rv(rv, name=rv.name, dims=dims)
+
+     def recover_marginals(
+         self,
+         idata: InferenceData,
+         var_names: Sequence[str] | None = None,
+         return_samples: bool = True,
+         extend_inferencedata: bool = True,
+         random_seed: RandomState = None,
+     ):
+ """Computes posterior log-probabilities and samples of marginalized variables
339
+ conditioned on parameters of the model given InferenceData with posterior group
340
+
341
+ When there are multiple marginalized variables, each marginalized variable is
342
+ conditioned on both the parameters and the other variables still marginalized
343
+
344
+ All log-probabilities are within the transformed space
345
+
346
+ Parameters
347
+ ----------
348
+ idata : InferenceData
349
+ InferenceData with posterior group
350
+ var_names : sequence of str, optional
351
+ List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables
352
+ return_samples : bool, default True
353
+ If True, also return samples of the marginalized variables
354
+ extend_inferencedata : bool, default True
355
+ Whether to extend the original InferenceData or return a new one
356
+ random_seed: int, array-like of int or SeedSequence, optional
357
+ Seed used to generating samples
358
+
359
+ Returns
360
+ -------
361
+ idata : InferenceData
362
+ InferenceData with where a lp_{varname} and {varname} for each marginalized variable in var_names added to the posterior group
363
+
364
+ .. code-block:: python
365
+
366
+ import pymc as pm
367
+ from pymc_extras import MarginalModel
368
+
369
+ with MarginalModel() as m:
370
+ p = pm.Beta("p", 1, 1)
371
+ x = pm.Bernoulli("x", p=p, shape=(3,))
372
+ y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
373
+
374
+ m.marginalize([x])
375
+
376
+ idata = pm.sample()
377
+ m.recover_marginals(idata, var_names=["x"])
378
+
379
+
380
+ """
381
+         if var_names is None:
+             var_names = [var.name for var in self.marginalized_rvs]
+
+         var_names = [var if isinstance(var, str) else var.name for var in var_names]
+         vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names]
+         missing_names = [v.name for v in vars_to_recover if v not in self.marginalized_rvs]
+         if missing_names:
+             raise ValueError(f"Unrecognized var_names: {missing_names}")
+
+         if return_samples and random_seed is not None:
+             seeds = _get_seeds_per_chain(random_seed, len(vars_to_recover))
+         else:
+             seeds = [None] * len(vars_to_recover)
+
+         posterior = idata.posterior
+
+         # Remove Deterministics
+         posterior_values = posterior[
+             [rv.name for rv in self.free_RVs if rv not in self.marginalized_rvs]
+         ]
+
+         sample_dims = ("chain", "draw")
+         posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
+
+         # Handle Transforms
+         transform_fn, transform_names = self._to_transformed()
+
+         def transform_input(inputs):
+             return dict(zip(transform_names, transform_fn(inputs)))
+
+         posterior_pts = [transform_input(vs) for vs in posterior_pts]
+
+         rv_dict = {}
+         rv_dims = {}
+         for seed, marginalized_rv in zip(seeds, vars_to_recover):
+             supported_dists = (Bernoulli, Categorical, DiscreteUniform)
+             if not isinstance(marginalized_rv.owner.op, supported_dists):
+                 raise NotImplementedError(
+                     f"RV with distribution {marginalized_rv.owner.op} cannot be recovered. "
+                     f"Supported distributions include {supported_dists}"
+                 )
+
+             m = self.clone()
+             marginalized_rv = m.vars_to_clone[marginalized_rv]
+             m.unmarginalize([marginalized_rv])
+             dependent_rvs = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
+             logps = m.logp(vars=[marginalized_rv, *dependent_rvs], sum=False)
+
+             # Handle batch dims for marginalized value and its dependent RVs
+             dependent_rvs_dim_connections = subgraph_batch_dim_connection(
+                 marginalized_rv, dependent_rvs
+             )
+             marginalized_logp, *dependent_logps = logps
+             joint_logp = marginalized_logp + reduce_batch_dependent_logps(
+                 dependent_rvs_dim_connections,
+                 [dependent_var.owner.op for dependent_var in dependent_rvs],
+                 dependent_logps,
+             )
+
+             marginalized_value = m.rvs_to_values[marginalized_rv]
+             other_values = [v for v in m.value_vars if v is not marginalized_value]
+
+             rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
+             rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
+             rv_domain_tensor = pt.moveaxis(
+                 pt.full(
+                     (*rv_shape, len(rv_domain)),
+                     rv_domain,
+                     dtype=marginalized_rv.dtype,
+                 ),
+                 -1,
+                 0,
+             )
+
+             batched_joint_logp = vectorize_graph(
+                 joint_logp,
+                 replace={marginalized_value: rv_domain_tensor},
+             )
+             batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1)
+
+             joint_logp_norm = log_softmax(batched_joint_logp, axis=-1)
+             if return_samples:
+                 rv_draws = pymc.Categorical.dist(logit_p=batched_joint_logp)
+                 if isinstance(marginalized_rv.owner.op, DiscreteUniform):
+                     rv_draws += rv_domain[0]
+                 outputs = [joint_logp_norm, rv_draws]
+             else:
+                 outputs = joint_logp_norm
+
+             rv_loglike_fn = compile_pymc(
+                 inputs=other_values,
+                 outputs=outputs,
+                 on_unused_input="ignore",
+                 random_seed=seed,
+             )
+
+             logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]
+
+             if return_samples:
+                 logps, samples = zip(*logvs)
+                 logps = np.array(logps)
+                 samples = np.array(samples)
+                 rv_dict[marginalized_rv.name] = samples.reshape(
+                     tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
+                 )
+             else:
+                 logps = np.array(logvs)
+
+             rv_dict["lp_" + marginalized_rv.name] = logps.reshape(
+                 tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
+             )
+             if marginalized_rv.name in m.named_vars_to_dims:
+                 rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name])
+                 rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [
+                     "lp_" + marginalized_rv.name + "_dim"
+                 ]
+
+         coords, dims = coords_and_dims_for_inferencedata(self)
+         dims.update(rv_dims)
+         rv_dataset = dict_to_dataset(
+             rv_dict,
+             library=pymc,
+             dims=dims,
+             coords=coords,
+             default_dims=list(sample_dims),
+             skip_event_dims=True,
+         )
+
+         if extend_inferencedata:
+             idata.posterior = idata.posterior.assign(rv_dataset)
+             return idata
+         else:
+             return rv_dataset
+
+
+ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
+     """Marginalize a subset of variables in a PyMC model.
+
+     This creates a `MarginalModel` from an existing `Model`, with the specified
+     variables marginalized.
+
+     See documentation for `MarginalModel` for more information.
+
+     Parameters
+     ----------
+     model : Model
+         PyMC model to marginalize. Original variables will be cloned.
+     rvs_to_marginalize : Sequence[TensorVariable]
+         Variables to marginalize in the returned model.
+
+     Returns
+     -------
+     marginal_model : MarginalModel
+         Marginal model with the specified variables marginalized.
+     """
+     if not isinstance(rvs_to_marginalize, tuple | list):
+         rvs_to_marginalize = (rvs_to_marginalize,)
+     rvs_to_marginalize = [rv if isinstance(rv, str) else rv.name for rv in rvs_to_marginalize]
+
+     marginal_model = MarginalModel.from_model(model)
+     marginal_model.marginalize(rvs_to_marginalize)
+     return marginal_model
+
+
+ def collect_shared_vars(outputs, blockers):
+     return [
+         inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)
+     ]
+
+
+ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
+     dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
+     if not dependent_rvs:
+         raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
+
+     marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
+     other_direct_rv_ancestors = [
+         rv
+         for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
+         if rv is not rv_to_marginalize
+     ]
+
+     # If the marginalized RV has multiple dimensions, check that the graph between the
+     # marginalized RV and the dependent RVs does not mix information from batch dimensions
+     # (otherwise logp would require enumerating over all combinations of batch dimension values)
+     try:
+         dependent_rvs_dim_connections = subgraph_batch_dim_connection(
+             rv_to_marginalize, dependent_rvs
+         )
+     except (ValueError, NotImplementedError) as e:
+         # From the perspective of the user this is a NotImplementedError
+         raise NotImplementedError(
+             "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. "
+             "You can try splitting the marginalized RV into separate components and marginalizing them separately."
+         ) from e
+
+     input_rvs = list(set((*marginalized_rv_input_rvs, *other_direct_rv_ancestors)))
+     output_rvs = [rv_to_marginalize, *dependent_rvs]
+
+     # We are strict about shared variables in SymbolicRandomVariables
+     inputs = input_rvs + collect_shared_vars(output_rvs, blockers=input_rvs)
+
+     if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
+         marginalize_constructor = MarginalDiscreteMarkovChainRV
+     else:
+         marginalize_constructor = MarginalFiniteDiscreteRV
+
+     marginalization_op = marginalize_constructor(
+         inputs=inputs,
+         outputs=output_rvs,  # TODO: Add RNG updates to outputs so this can be used in the generative graph
+         dims_connections=dependent_rvs_dim_connections,
+     )
+     new_output_rvs = marginalization_op(*inputs)
+     fgraph.replace_all(tuple(zip(output_rvs, new_output_rvs)))
+     return output_rvs, new_output_rvs
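
The sketch below is not part of the packaged file; it is a minimal usage example assembled from the docstring examples above, showing how the module's two public entry points fit together: `marginalize` clones an ordinary `pm.Model` into a `MarginalModel` with the chosen finite discrete variable summed out of the logp, and `recover_marginals` later adds posterior log-probabilities (`lp_x`) and draws (`x`) of that variable back into the InferenceData. The model and variable names are illustrative only.

.. code-block:: python

    import pymc as pm

    from pymc_extras.model.marginal.marginal_model import marginalize

    # An ordinary PyMC model with a finite discrete latent variable x
    with pm.Model() as model:
        p = pm.Beta("p", 1, 1)
        x = pm.Bernoulli("x", p=p, shape=(3,))
        y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

    # Clone the model and marginalize x out of its logp graph
    marginal_model = marginalize(model, ["x"])

    # Sample the remaining free variables of the marginalized model
    idata = pm.sample(model=marginal_model)

    # Add posterior log-probabilities (lp_x) and draws (x) of the
    # marginalized variable back into idata.posterior
    marginal_model.recover_marginals(idata, var_names=["x"])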
pymc_extras/model/model_api.py
@@ -0,0 +1,56 @@
+ from functools import wraps
+
+ from pymc import Model
+
+
+ def as_model(*model_args, **model_kwargs):
+     R"""
+     Decorator to provide context to PyMC models declared in a function.
+     This removes all need to think about context managers and lets you separate creating a generative model from using the model.
+     Additionally, a coords argument is added to the function so coords can be changed during function invocation.
+
+     Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
+
+     Examples
+     --------
+     .. code:: python
+
+         import pymc as pm
+         import pymc_extras as pmx
+
+         # The following are equivalent
+
+         # standard PyMC API with context manager
+         with pm.Model(coords={"obs": ["a", "b"]}) as model:
+             x = pm.Normal("x", 0., 1., dims="obs")
+             pm.sample()
+
+         # functional API using decorator
+         @pmx.as_model(coords={"obs": ["a", "b"]})
+         def basic_model():
+             pm.Normal("x", 0., 1., dims="obs")
+
+         m = basic_model()
+         pm.sample(model=m)
+
+         # alternative way to use functional API
+         @pmx.as_model()
+         def basic_model():
+             pm.Normal("x", 0., 1., dims="obs")
+
+         m = basic_model(coords={"obs": ["a", "b"]})
+         pm.sample(model=m)
+
+     """
+
+     def decorator(f):
+         @wraps(f)
+         def make_model(*args, **kwargs):
+             coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
+             with Model(*model_args, coords=coords, **model_kwargs) as m:
+                 f(*args, **kwargs)
+                 return m
+
+         return make_model
+
+     return decorator
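
Because `make_model` forwards its `*args`/`**kwargs` to the wrapped function and merges any `coords` passed at call time with the decorator's own `coords` (a dict union, so call-time values win on duplicate keys), the generative function can also take ordinary data arguments. The sketch below is illustrative only and not part of the packaged file; the function and data names are made up.

.. code-block:: python

    import numpy as np
    import pymc as pm
    import pymc_extras as pmx

    @pmx.as_model(coords={"obs": ["a", "b", "c"]})
    def regression(x, y=None):
        beta = pm.Normal("beta", 0.0, 1.0)
        pm.Normal("y", beta * x, 1.0, observed=y, dims="obs")

    # Positional and keyword arguments are passed through to the wrapped
    # function; coords given here would be merged with the decorator's coords.
    m = regression(np.array([1.0, 2.0, 3.0]), y=np.array([1.1, 1.9, 3.2]))
    idata = pm.sample(model=m)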