pymc-extras 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymc_extras/__init__.py CHANGED
@@ -16,7 +16,11 @@ import logging
 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
 from pymc_extras.inference.fit import fit
-from pymc_extras.model.marginal.marginal_model import MarginalModel, marginalize
+from pymc_extras.model.marginal.marginal_model import (
+    MarginalModel,
+    marginalize,
+    recover_marginals,
+)
 from pymc_extras.model.model_api import as_model
 from pymc_extras.version import __version__
 
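Note on the hunk above: the top-level namespace now re-exports recover_marginals alongside MarginalModel and marginalize. A minimal usage sketch follows; the model and variable names are illustrative only, and the exact call signatures of marginalize/recover_marginals should be confirmed against pymc_extras.model.marginal.marginal_model:

    import pymc as pm
    import pymc_extras as pmx

    with pm.Model() as model:
        idx = pm.Bernoulli("idx", p=0.7)  # discrete RV to marginalize
        mu = pm.Normal("mu", mu=idx, sigma=1.0)
        y = pm.Normal("y", mu=mu, sigma=1.0, observed=[0.1, 0.3])

    # Marginalize the discrete variable, sample the remaining model, then
    # recover posterior draws of the marginalized variable afterwards.
    marginal_model = pmx.marginalize(model, [idx])
    with marginal_model:
        idata = pm.sample()
    idata = pmx.recover_marginals(marginal_model, idata)  # assumed call pattern
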
@@ -214,8 +214,8 @@ class DiscreteMarkovChain(Distribution):
         discrete_mc_op = DiscreteMarkovChainRV(
             inputs=[P_, steps_, init_dist_, state_rng],
             outputs=[state_next_rng, discrete_mc_],
-            ndim_supp=1,
             n_lags=n_lags,
+            extended_signature="(p,p),(),(p),[rng]->[rng],(t)",
         )

         discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng)
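The hunk above replaces the bare ndim_supp=1 with an extended_signature, which spells out the core dimensions of every input and output (for example the (p, p) transition matrix and the (t)-length chain). The user-facing distribution is unchanged; a sketch of typical use, with illustrative shapes:

    import numpy as np
    import pymc as pm
    from pymc_extras.distributions import DiscreteMarkovChain

    with pm.Model() as model:
        # 3-state transition matrix with Dirichlet-distributed rows
        P = pm.Dirichlet("P", a=np.ones((3, 3)))
        init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=10)
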
@@ -1,20 +1,25 @@
+import warnings
+
 from collections.abc import Sequence

 import numpy as np
 import pytensor.tensor as pt

 from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+from pymc.distributions.distribution import _support_point, support_point
 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
 from pytensor import Variable
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.mode import Mode
-from pytensor.graph import Op, vectorize_graph
+from pytensor.graph import FunctionGraph, Op, vectorize_graph
+from pytensor.graph.basic import equal_computations
 from pytensor.graph.replace import clone_replace, graph_replace
 from pytensor.scan import map as scan_map
 from pytensor.scan import scan
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.type import RandomType

 from pymc_extras.distributions import DiscreteMarkovChain

@@ -22,8 +27,15 @@ from pymc_extras.distributions import DiscreteMarkovChain
 class MarginalRV(OpFromGraph, MeasurableOp):
     """Base class for Marginalized RVs"""

-    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
+    def __init__(
+        self,
+        *args,
+        dims_connections: tuple[tuple[int | None], ...],
+        dims: tuple[Variable, ...],
+        **kwargs,
+    ) -> None:
         self.dims_connections = dims_connections
+        self.dims = dims
         super().__init__(*args, **kwargs)

     @property
@@ -43,6 +55,74 @@ class MarginalRV(OpFromGraph, MeasurableOp):
             )
         return tuple(support_axes_vars)

+    def __eq__(self, other):
+        # Just to allow easy testing of equivalent models,
+        # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
+        if type(self) is not type(other):
+            return False
+
+        return equal_computations(
+            self.inner_outputs,
+            other.inner_outputs,
+            self.inner_inputs,
+            other.inner_inputs,
+        )
+
+    def __hash__(self):
+        # Just to allow easy testing of equivalent models,
+        # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
+        return hash((type(self), len(self.inner_inputs), len(self.inner_outputs)))
+
+
+@_support_point.register
+def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
+    """Support point for a marginalized RV.
+
+    The support point of a marginalized RV is the support point of the inner RV,
+    conditioned on the marginalized RV taking its support point.
+    """
+    outputs = rv.owner.outputs
+
+    inner_rv = op.inner_outputs[outputs.index(rv)]
+    marginalized_inner_rv, *other_dependent_inner_rvs = (
+        out
+        for out in op.inner_outputs
+        if out is not inner_rv and not isinstance(out.type, RandomType)
+    )
+
+    # Replace references to inner rvs by the dummy variables (including the marginalized RV)
+    # This is necessary because the inner RVs may depend on each other
+    marginalized_inner_rv_dummy = marginalized_inner_rv.clone()
+    other_dependent_inner_rv_to_dummies = {
+        inner_rv: inner_rv.clone() for inner_rv in other_dependent_inner_rvs
+    }
+    inner_rv = clone_replace(
+        inner_rv,
+        replace={marginalized_inner_rv: marginalized_inner_rv_dummy}
+        | other_dependent_inner_rv_to_dummies,
+    )
+
+    # Get support point of inner RV and marginalized RV
+    inner_rv_support_point = support_point(inner_rv)
+    marginalized_inner_rv_support_point = support_point(marginalized_inner_rv)
+
+    replacements = [
+        # Replace the marginalized RV dummy by its support point
+        (marginalized_inner_rv_dummy, marginalized_inner_rv_support_point),
+        # Replace other dependent RVs dummies by the respective outer outputs.
+        # PyMC will replace them by their support points later
+        *(
+            (v, outputs[op.inner_outputs.index(k)])
+            for k, v in other_dependent_inner_rv_to_dummies.items()
+        ),
+        # Replace outer input RVs
+        *zip(op.inner_inputs, inputs),
+    ]
+    fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False)
+    fgraph.replace_all(replacements, import_missing=True)
+    [rv_support_point] = fgraph.outputs
+    return rv_support_point
+

 class MarginalFiniteDiscreteRV(MarginalRV):
     """Base class for Marginalized Finite Discrete RVs"""
@@ -132,12 +212,27 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Var
     Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
     the inner graph.
     """
-    return clone_replace(
+    return graph_replace(
         op.inner_outputs,
         replace=tuple(zip(op.inner_inputs, inputs)),
+        strict=False,
     )


+class NonSeparableLogpWarning(UserWarning):
+    pass
+
+
+def warn_non_separable_logp(values):
+    if len(values) > 1:
+        warnings.warn(
+            "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
+            f"Their joint logp terms will be assigned to the first value: {values[0]}.",
+            NonSeparableLogpWarning,
+            stacklevel=2,
+        )
+
+
 DUMMY_ZERO = pt.constant(0, name="dummy_zero")

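The new NonSeparableLogpWarning above is emitted whenever a marginalized RV has more than one dependent value variable, because their joint logp is attributed entirely to the first value. If that attribution is acceptable, the warning can be filtered like any other Python warning; a sketch (the import path mirrors the module shown in this hunk and may need adjusting):

    import warnings

    from pymc_extras.model.marginal.distributions import NonSeparableLogpWarning

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=NonSeparableLogpWarning)
        # ... build the marginalized model / compile its logp here ...
        pass
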
@@ -199,6 +294,7 @@ def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inpu
     # Align logp with non-collapsed batch dimensions of first RV
     joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)

+    warn_non_separable_logp(values)
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
@@ -272,5 +368,6 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):

     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
     # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+    warn_non_separable_logp(values)
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
@@ -4,8 +4,8 @@ from collections.abc import Sequence
 from itertools import zip_longest

 from pymc import SymbolicRandomVariable
-from pytensor.compile import SharedVariable
-from pytensor.graph import Constant, Variable, ancestors
+from pymc.model.fgraph import ModelVar
+from pytensor.graph import Variable, ancestors
 from pytensor.graph.basic import io_toposort
 from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.blockwise import Blockwise
@@ -35,13 +35,9 @@ def static_shape_ancestors(vars):

 def find_conditional_input_rvs(output_rvs, all_rvs):
     """Find conditionally indepedent input RVs."""
-    blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
-    blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
-    return [
-        var
-        for var in ancestors(output_rvs, blockers=blockers)
-        if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
-    ]
+    other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
+    blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
+    return [var for var in ancestors(output_rvs, blockers=blockers) if var in other_rvs]


 def is_conditional_dependent(
@@ -141,6 +137,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
             # None of the inputs are related to the batch_axes of the input_vars
             continue

+        elif isinstance(node.op, ModelVar):
+            var_dims[node.outputs[0]] = inputs_dims[0]
+
         elif isinstance(node.op, DimShuffle):
             [input_dims] = inputs_dims
             output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)