pymc-extras 0.2.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/PKG-INFO +3 -4
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/README.md +2 -3
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/__init__.py +5 -1
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/distributions/timeseries.py +1 -1
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/model/marginal/distributions.py +100 -3
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/model/marginal/graph_analysis.py +8 -9
- pymc_extras-0.2.1/pymc_extras/model/marginal/marginal_model.py +608 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/models/structural.py +21 -6
- pymc_extras-0.2.0/tests/utils.py → pymc_extras-0.2.1/pymc_extras/utils/model_equivalence.py +38 -3
- pymc_extras-0.2.1/pymc_extras/version.txt +1 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras.egg-info/PKG-INFO +3 -4
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras.egg-info/SOURCES.txt +1 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/model/marginal/test_distributions.py +12 -11
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/model/marginal/test_marginal_model.py +301 -201
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/test_structural.py +10 -3
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/test_pivoted_cholesky.py +1 -1
- pymc_extras-0.2.1/tests/utils.py +0 -0
- pymc_extras-0.2.0/pymc_extras/model/marginal/marginal_model.py +0 -595
- pymc_extras-0.2.0/pymc_extras/version.txt +0 -1
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/CODE_OF_CONDUCT.md +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/CONTRIBUTING.md +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/LICENSE +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/MANIFEST.in +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/distributions/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/distributions/continuous.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/distributions/discrete.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/distributions/histogram_utils.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/distributions/multivariate/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/distributions/multivariate/r2d2m2cp.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/gp/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/gp/latent_approx.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/inference/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/inference/find_map.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/inference/fit.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/inference/laplace.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/inference/pathfinder.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/inference/smc/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/inference/smc/sampling.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/linearmodel.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/model/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/model/marginal/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/model/model_api.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/model/transforms/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/model/transforms/autoreparam.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/model_builder.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/preprocessing/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/preprocessing/standard_scaler.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/printing.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/core/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/core/compile.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/core/representation.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/core/statespace.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/filters/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/filters/distributions.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/filters/kalman_filter.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/filters/kalman_smoother.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/filters/utilities.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/models/ETS.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/models/SARIMAX.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/models/VARMAX.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/models/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/models/utilities.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/utils/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/utils/constants.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/utils/coord_tools.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/statespace/utils/data_tools.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/utils/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/utils/linear_cg.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/utils/pivoted_cholesky.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/utils/prior.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/utils/spline.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras/version.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras.egg-info/dependency_links.txt +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras.egg-info/requires.txt +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pymc_extras.egg-info/top_level.txt +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/pyproject.toml +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/requirements-dev.txt +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/requirements-docs.txt +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/requirements.txt +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/setup.cfg +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/setup.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/distributions/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/distributions/test_continuous.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/distributions/test_discrete.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/distributions/test_discrete_markov_chain.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/distributions/test_multivariate.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/model/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/model/marginal/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/model/marginal/test_graph_analysis.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/model/test_model_api.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/test_ETS.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/test_SARIMAX.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/test_VARMAX.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/test_coord_assignment.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/test_distributions.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/test_kalman_filter.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/test_representation.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/test_statespace.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/test_statespace_JAX.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/utilities/__init__.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/utilities/shared_fixtures.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/utilities/statsmodel_local_level.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/statespace/utilities/test_helpers.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/test_blackjax_smc.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/test_find_map.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/test_histogram_approximation.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/test_laplace.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/test_linearmodel.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/test_model_builder.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/test_pathfinder.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/test_printing.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/test_prior_from_trace.py +0 -0
- {pymc_extras-0.2.0 → pymc_extras-0.2.1}/tests/test_splines.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: pymc-extras
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
|
|
5
5
|
Home-page: http://github.com/pymc-devs/pymc-extras
|
|
6
6
|
Maintainer: PyMC Developers
|
|
@@ -63,10 +63,9 @@ import pymc as pm
|
|
|
63
63
|
import pymc_extras as pmx
|
|
64
64
|
|
|
65
65
|
with pm.Model():
|
|
66
|
+
alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
|
|
66
67
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
...
|
|
68
|
+
...
|
|
70
69
|
|
|
71
70
|
```
|
|
72
71
|
|
|
@@ -16,7 +16,11 @@ import logging
|
|
|
16
16
|
from pymc_extras import gp, statespace, utils
|
|
17
17
|
from pymc_extras.distributions import *
|
|
18
18
|
from pymc_extras.inference.fit import fit
|
|
19
|
-
from pymc_extras.model.marginal.marginal_model import
|
|
19
|
+
from pymc_extras.model.marginal.marginal_model import (
|
|
20
|
+
MarginalModel,
|
|
21
|
+
marginalize,
|
|
22
|
+
recover_marginals,
|
|
23
|
+
)
|
|
20
24
|
from pymc_extras.model.model_api import as_model
|
|
21
25
|
from pymc_extras.version import __version__
|
|
22
26
|
|
|
@@ -214,8 +214,8 @@ class DiscreteMarkovChain(Distribution):
|
|
|
214
214
|
discrete_mc_op = DiscreteMarkovChainRV(
|
|
215
215
|
inputs=[P_, steps_, init_dist_, state_rng],
|
|
216
216
|
outputs=[state_next_rng, discrete_mc_],
|
|
217
|
-
ndim_supp=1,
|
|
218
217
|
n_lags=n_lags,
|
|
218
|
+
extended_signature="(p,p),(),(p),[rng]->[rng],(t)",
|
|
219
219
|
)
|
|
220
220
|
|
|
221
221
|
discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng)
|
|
@@ -1,20 +1,25 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
1
3
|
from collections.abc import Sequence
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
4
6
|
import pytensor.tensor as pt
|
|
5
7
|
|
|
6
8
|
from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
|
|
9
|
+
from pymc.distributions.distribution import _support_point, support_point
|
|
7
10
|
from pymc.logprob.abstract import MeasurableOp, _logprob
|
|
8
11
|
from pymc.logprob.basic import conditional_logp, logp
|
|
9
12
|
from pymc.pytensorf import constant_fold
|
|
10
13
|
from pytensor import Variable
|
|
11
14
|
from pytensor.compile.builders import OpFromGraph
|
|
12
15
|
from pytensor.compile.mode import Mode
|
|
13
|
-
from pytensor.graph import Op, vectorize_graph
|
|
16
|
+
from pytensor.graph import FunctionGraph, Op, vectorize_graph
|
|
17
|
+
from pytensor.graph.basic import equal_computations
|
|
14
18
|
from pytensor.graph.replace import clone_replace, graph_replace
|
|
15
19
|
from pytensor.scan import map as scan_map
|
|
16
20
|
from pytensor.scan import scan
|
|
17
21
|
from pytensor.tensor import TensorVariable
|
|
22
|
+
from pytensor.tensor.random.type import RandomType
|
|
18
23
|
|
|
19
24
|
from pymc_extras.distributions import DiscreteMarkovChain
|
|
20
25
|
|
|
@@ -22,8 +27,15 @@ from pymc_extras.distributions import DiscreteMarkovChain
|
|
|
22
27
|
class MarginalRV(OpFromGraph, MeasurableOp):
|
|
23
28
|
"""Base class for Marginalized RVs"""
|
|
24
29
|
|
|
25
|
-
def __init__(
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
*args,
|
|
33
|
+
dims_connections: tuple[tuple[int | None], ...],
|
|
34
|
+
dims: tuple[Variable, ...],
|
|
35
|
+
**kwargs,
|
|
36
|
+
) -> None:
|
|
26
37
|
self.dims_connections = dims_connections
|
|
38
|
+
self.dims = dims
|
|
27
39
|
super().__init__(*args, **kwargs)
|
|
28
40
|
|
|
29
41
|
@property
|
|
@@ -43,6 +55,74 @@ class MarginalRV(OpFromGraph, MeasurableOp):
|
|
|
43
55
|
)
|
|
44
56
|
return tuple(support_axes_vars)
|
|
45
57
|
|
|
58
|
+
def __eq__(self, other):
|
|
59
|
+
# Just to allow easy testing of equivalent models,
|
|
60
|
+
# This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
|
|
61
|
+
if type(self) is not type(other):
|
|
62
|
+
return False
|
|
63
|
+
|
|
64
|
+
return equal_computations(
|
|
65
|
+
self.inner_outputs,
|
|
66
|
+
other.inner_outputs,
|
|
67
|
+
self.inner_inputs,
|
|
68
|
+
other.inner_inputs,
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
def __hash__(self):
|
|
72
|
+
# Just to allow easy testing of equivalent models,
|
|
73
|
+
# This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
|
|
74
|
+
return hash((type(self), len(self.inner_inputs), len(self.inner_outputs)))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@_support_point.register
|
|
78
|
+
def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
|
|
79
|
+
"""Support point for a marginalized RV.
|
|
80
|
+
|
|
81
|
+
The support point of a marginalized RV is the support point of the inner RV,
|
|
82
|
+
conditioned on the marginalized RV taking its support point.
|
|
83
|
+
"""
|
|
84
|
+
outputs = rv.owner.outputs
|
|
85
|
+
|
|
86
|
+
inner_rv = op.inner_outputs[outputs.index(rv)]
|
|
87
|
+
marginalized_inner_rv, *other_dependent_inner_rvs = (
|
|
88
|
+
out
|
|
89
|
+
for out in op.inner_outputs
|
|
90
|
+
if out is not inner_rv and not isinstance(out.type, RandomType)
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
# Replace references to inner rvs by the dummy variables (including the marginalized RV)
|
|
94
|
+
# This is necessary because the inner RVs may depend on each other
|
|
95
|
+
marginalized_inner_rv_dummy = marginalized_inner_rv.clone()
|
|
96
|
+
other_dependent_inner_rv_to_dummies = {
|
|
97
|
+
inner_rv: inner_rv.clone() for inner_rv in other_dependent_inner_rvs
|
|
98
|
+
}
|
|
99
|
+
inner_rv = clone_replace(
|
|
100
|
+
inner_rv,
|
|
101
|
+
replace={marginalized_inner_rv: marginalized_inner_rv_dummy}
|
|
102
|
+
| other_dependent_inner_rv_to_dummies,
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
# Get support point of inner RV and marginalized RV
|
|
106
|
+
inner_rv_support_point = support_point(inner_rv)
|
|
107
|
+
marginalized_inner_rv_support_point = support_point(marginalized_inner_rv)
|
|
108
|
+
|
|
109
|
+
replacements = [
|
|
110
|
+
# Replace the marginalized RV dummy by its support point
|
|
111
|
+
(marginalized_inner_rv_dummy, marginalized_inner_rv_support_point),
|
|
112
|
+
# Replace other dependent RVs dummies by the respective outer outputs.
|
|
113
|
+
# PyMC will replace them by their support points later
|
|
114
|
+
*(
|
|
115
|
+
(v, outputs[op.inner_outputs.index(k)])
|
|
116
|
+
for k, v in other_dependent_inner_rv_to_dummies.items()
|
|
117
|
+
),
|
|
118
|
+
# Replace outer input RVs
|
|
119
|
+
*zip(op.inner_inputs, inputs),
|
|
120
|
+
]
|
|
121
|
+
fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False)
|
|
122
|
+
fgraph.replace_all(replacements, import_missing=True)
|
|
123
|
+
[rv_support_point] = fgraph.outputs
|
|
124
|
+
return rv_support_point
|
|
125
|
+
|
|
46
126
|
|
|
47
127
|
class MarginalFiniteDiscreteRV(MarginalRV):
|
|
48
128
|
"""Base class for Marginalized Finite Discrete RVs"""
|
|
@@ -132,12 +212,27 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Var
|
|
|
132
212
|
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
|
|
133
213
|
the inner graph.
|
|
134
214
|
"""
|
|
135
|
-
return
|
|
215
|
+
return graph_replace(
|
|
136
216
|
op.inner_outputs,
|
|
137
217
|
replace=tuple(zip(op.inner_inputs, inputs)),
|
|
218
|
+
strict=False,
|
|
138
219
|
)
|
|
139
220
|
|
|
140
221
|
|
|
222
|
+
class NonSeparableLogpWarning(UserWarning):
|
|
223
|
+
pass
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def warn_non_separable_logp(values):
|
|
227
|
+
if len(values) > 1:
|
|
228
|
+
warnings.warn(
|
|
229
|
+
"There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
|
|
230
|
+
f"Their joint logp terms will be assigned to the first value: {values[0]}.",
|
|
231
|
+
NonSeparableLogpWarning,
|
|
232
|
+
stacklevel=2,
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
|
|
141
236
|
DUMMY_ZERO = pt.constant(0, name="dummy_zero")
|
|
142
237
|
|
|
143
238
|
|
|
@@ -199,6 +294,7 @@ def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inpu
|
|
|
199
294
|
# Align logp with non-collapsed batch dimensions of first RV
|
|
200
295
|
joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)
|
|
201
296
|
|
|
297
|
+
warn_non_separable_logp(values)
|
|
202
298
|
# We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
|
|
203
299
|
dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
|
|
204
300
|
return joint_logp, *dummy_logps
|
|
@@ -272,5 +368,6 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
|
|
|
272
368
|
|
|
273
369
|
# If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
|
|
274
370
|
# return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
|
|
371
|
+
warn_non_separable_logp(values)
|
|
275
372
|
dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
|
|
276
373
|
return joint_logp, *dummy_logps
|
|
@@ -4,8 +4,8 @@ from collections.abc import Sequence
|
|
|
4
4
|
from itertools import zip_longest
|
|
5
5
|
|
|
6
6
|
from pymc import SymbolicRandomVariable
|
|
7
|
-
from
|
|
8
|
-
from pytensor.graph import
|
|
7
|
+
from pymc.model.fgraph import ModelVar
|
|
8
|
+
from pytensor.graph import Variable, ancestors
|
|
9
9
|
from pytensor.graph.basic import io_toposort
|
|
10
10
|
from pytensor.tensor import TensorType, TensorVariable
|
|
11
11
|
from pytensor.tensor.blockwise import Blockwise
|
|
@@ -35,13 +35,9 @@ def static_shape_ancestors(vars):
|
|
|
35
35
|
|
|
36
36
|
def find_conditional_input_rvs(output_rvs, all_rvs):
|
|
37
37
|
"""Find conditionally indepedent input RVs."""
|
|
38
|
-
|
|
39
|
-
blockers
|
|
40
|
-
return [
|
|
41
|
-
var
|
|
42
|
-
for var in ancestors(output_rvs, blockers=blockers)
|
|
43
|
-
if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
|
|
44
|
-
]
|
|
38
|
+
other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
|
|
39
|
+
blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
|
|
40
|
+
return [var for var in ancestors(output_rvs, blockers=blockers) if var in other_rvs]
|
|
45
41
|
|
|
46
42
|
|
|
47
43
|
def is_conditional_dependent(
|
|
@@ -141,6 +137,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
|
|
|
141
137
|
# None of the inputs are related to the batch_axes of the input_vars
|
|
142
138
|
continue
|
|
143
139
|
|
|
140
|
+
elif isinstance(node.op, ModelVar):
|
|
141
|
+
var_dims[node.outputs[0]] = inputs_dims[0]
|
|
142
|
+
|
|
144
143
|
elif isinstance(node.op, DimShuffle):
|
|
145
144
|
[input_dims] = inputs_dims
|
|
146
145
|
output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
|