pymc-extras 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +5 -1
- pymc_extras/distributions/timeseries.py +1 -1
- pymc_extras/inference/fit.py +0 -4
- pymc_extras/inference/pathfinder/__init__.py +3 -0
- pymc_extras/inference/pathfinder/importance_sampling.py +139 -0
- pymc_extras/inference/pathfinder/lbfgs.py +190 -0
- pymc_extras/inference/pathfinder/pathfinder.py +1746 -0
- pymc_extras/model/marginal/distributions.py +100 -3
- pymc_extras/model/marginal/graph_analysis.py +8 -9
- pymc_extras/model/marginal/marginal_model.py +437 -424
- pymc_extras/model/model_api.py +18 -2
- pymc_extras/statespace/core/statespace.py +79 -36
- pymc_extras/statespace/models/structural.py +21 -6
- pymc_extras/utils/model_equivalence.py +66 -0
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.2.dist-info}/METADATA +15 -5
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.2.dist-info}/RECORD +28 -24
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.2.dist-info}/WHEEL +1 -1
- tests/model/marginal/test_distributions.py +12 -11
- tests/model/marginal/test_marginal_model.py +301 -201
- tests/model/test_model_api.py +9 -0
- tests/statespace/test_statespace.py +54 -0
- tests/statespace/test_structural.py +10 -3
- tests/test_pathfinder.py +135 -7
- tests/test_pivoted_cholesky.py +1 -1
- tests/utils.py +0 -31
- pymc_extras/inference/pathfinder.py +0 -134
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.2.dist-info}/LICENSE +0 -0
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.2.dist-info}/top_level.txt +0 -0
pymc_extras/model/model_api.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
from functools import wraps
|
|
2
|
+
from inspect import signature
|
|
2
3
|
|
|
3
|
-
|
|
4
|
+
import pytensor.tensor as pt
|
|
5
|
+
|
|
6
|
+
from pymc import Data, Model
|
|
4
7
|
|
|
5
8
|
|
|
6
9
|
def as_model(*model_args, **model_kwargs):
|
|
@@ -9,6 +12,8 @@ def as_model(*model_args, **model_kwargs):
|
|
|
9
12
|
This removes all need to think about context managers and lets you separate creating a generative model from using the model.
|
|
10
13
|
Additionally, a coords argument is added to the function so coords can be changed during function invocation
|
|
11
14
|
|
|
15
|
+
All parameters are wrapped with a `pm.Data` object if the underlying type of the data supports it.
|
|
16
|
+
|
|
12
17
|
Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
|
|
13
18
|
|
|
14
19
|
Examples
|
|
@@ -47,8 +52,19 @@ def as_model(*model_args, **model_kwargs):
|
|
|
47
52
|
@wraps(f)
|
|
48
53
|
def make_model(*args, **kwargs):
|
|
49
54
|
coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
|
|
55
|
+
sig = signature(f)
|
|
56
|
+
ba = sig.bind(*args, **kwargs)
|
|
57
|
+
ba.apply_defaults()
|
|
58
|
+
|
|
50
59
|
with Model(*model_args, coords=coords, **model_kwargs) as m:
|
|
51
|
-
|
|
60
|
+
for name, v in ba.arguments.items():
|
|
61
|
+
# Only wrap pm.Data around values pytensor can process
|
|
62
|
+
try:
|
|
63
|
+
_ = pt.as_tensor_variable(v)
|
|
64
|
+
ba.arguments[name] = Data(name, v)
|
|
65
|
+
except (NotImplementedError, TypeError, ValueError):
|
|
66
|
+
pass
|
|
67
|
+
f(*ba.args, **ba.kwargs)
|
|
52
68
|
return m
|
|
53
69
|
|
|
54
70
|
return make_model
|
|
@@ -15,6 +15,9 @@ from pymc.model.transform.optimization import freeze_dims_and_data
|
|
|
15
15
|
from pymc.util import RandomState
|
|
16
16
|
from pytensor import Variable, graph_replace
|
|
17
17
|
from pytensor.compile import get_mode
|
|
18
|
+
from rich.box import SIMPLE_HEAD
|
|
19
|
+
from rich.console import Console
|
|
20
|
+
from rich.table import Table
|
|
18
21
|
|
|
19
22
|
from pymc_extras.statespace.core.representation import PytensorRepresentation
|
|
20
23
|
from pymc_extras.statespace.filters import (
|
|
@@ -254,53 +257,72 @@ class PyMCStateSpace:
|
|
|
254
257
|
self.kalman_smoother = KalmanSmoother()
|
|
255
258
|
self.make_symbolic_graph()
|
|
256
259
|
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
def _print_prior_requirements(self) -> None:
|
|
270
|
-
"""
|
|
271
|
-
Prints a short report to the terminal about the priors needed for the model, including their names,
|
|
260
|
+
self.requirement_table = None
|
|
261
|
+
self._populate_prior_requirements()
|
|
262
|
+
self._populate_data_requirements()
|
|
263
|
+
|
|
264
|
+
if verbose and self.requirement_table:
|
|
265
|
+
console = Console()
|
|
266
|
+
console.print(self.requirement_table)
|
|
267
|
+
|
|
268
|
+
def _populate_prior_requirements(self) -> None:
|
|
269
|
+
"""
|
|
270
|
+
Add requirements about priors needed for the model to a rich table, including their names,
|
|
272
271
|
shapes, named dimensions, and any parameter constraints.
|
|
273
272
|
"""
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
273
|
+
# Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
|
|
274
|
+
# is not true.
|
|
275
|
+
try:
|
|
276
|
+
if not isinstance(self.param_info, dict):
|
|
277
|
+
return
|
|
278
|
+
except NotImplementedError:
|
|
279
|
+
return
|
|
278
280
|
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
281
|
+
if self.requirement_table is None:
|
|
282
|
+
self._initialize_requirement_table()
|
|
283
|
+
|
|
284
|
+
for param, info in self.param_info.items():
|
|
285
|
+
self.requirement_table.add_row(
|
|
286
|
+
param, str(info["shape"]), info["constraints"], str(info["dims"])
|
|
287
|
+
)
|
|
284
288
|
|
|
285
|
-
def
|
|
289
|
+
def _populate_data_requirements(self) -> None:
|
|
286
290
|
"""
|
|
287
|
-
|
|
288
|
-
and named dimensions.
|
|
291
|
+
Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
|
|
289
292
|
"""
|
|
290
|
-
|
|
293
|
+
try:
|
|
294
|
+
if not isinstance(self.data_info, dict):
|
|
295
|
+
return
|
|
296
|
+
except NotImplementedError:
|
|
291
297
|
return
|
|
292
298
|
|
|
293
|
-
|
|
299
|
+
if self.requirement_table is None:
|
|
300
|
+
self._initialize_requirement_table()
|
|
301
|
+
else:
|
|
302
|
+
self.requirement_table.add_section()
|
|
303
|
+
|
|
294
304
|
for data, info in self.data_info.items():
|
|
295
|
-
|
|
296
|
-
|
|
305
|
+
self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"]))
|
|
306
|
+
|
|
307
|
+
def _initialize_requirement_table(self) -> None:
|
|
308
|
+
self.requirement_table = Table(
|
|
309
|
+
show_header=True,
|
|
310
|
+
show_edge=True,
|
|
311
|
+
box=SIMPLE_HEAD,
|
|
312
|
+
highlight=True,
|
|
313
|
+
)
|
|
297
314
|
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
315
|
+
self.requirement_table.title = "Model Requirements"
|
|
316
|
+
self.requirement_table.caption = (
|
|
317
|
+
"These parameters should be assigned priors inside a PyMC model block before "
|
|
318
|
+
"calling the build_statespace_graph method."
|
|
302
319
|
)
|
|
303
320
|
|
|
321
|
+
self.requirement_table.add_column("Variable", justify="left")
|
|
322
|
+
self.requirement_table.add_column("Shape", justify="left")
|
|
323
|
+
self.requirement_table.add_column("Constraints", justify="left")
|
|
324
|
+
self.requirement_table.add_column("Dimensions", justify="right")
|
|
325
|
+
|
|
304
326
|
def _unpack_statespace_with_placeholders(
|
|
305
327
|
self,
|
|
306
328
|
) -> tuple[
|
|
@@ -961,10 +983,31 @@ class PyMCStateSpace:
|
|
|
961
983
|
list[pm.Flat]
|
|
962
984
|
A list of pm.Flat variables representing all parameters estimated by the model.
|
|
963
985
|
"""
|
|
986
|
+
|
|
987
|
+
def infer_variable_shape(name):
|
|
988
|
+
shape = self._name_to_variable[name].type.shape
|
|
989
|
+
if not any(dim is None for dim in shape):
|
|
990
|
+
return shape
|
|
991
|
+
|
|
992
|
+
dim_names = self._fit_dims.get(name, None)
|
|
993
|
+
if dim_names is None:
|
|
994
|
+
raise ValueError(
|
|
995
|
+
f"Could not infer shape for {name}, because it was not given coords during model"
|
|
996
|
+
f"fitting"
|
|
997
|
+
)
|
|
998
|
+
|
|
999
|
+
shape_from_coords = tuple([len(self._fit_coords[dim]) for dim in dim_names])
|
|
1000
|
+
return tuple(
|
|
1001
|
+
[
|
|
1002
|
+
shape[i] if shape[i] is not None else shape_from_coords[i]
|
|
1003
|
+
for i in range(len(shape))
|
|
1004
|
+
]
|
|
1005
|
+
)
|
|
1006
|
+
|
|
964
1007
|
for name in self.param_names:
|
|
965
1008
|
pm.Flat(
|
|
966
1009
|
name,
|
|
967
|
-
shape=
|
|
1010
|
+
shape=infer_variable_shape(name),
|
|
968
1011
|
dims=self._fit_dims.get(name, None),
|
|
969
1012
|
)
|
|
970
1013
|
|
|
@@ -1071,6 +1071,12 @@ class TimeSeasonality(Component):
|
|
|
1071
1071
|
|
|
1072
1072
|
If None, states will be numbered ``[State_0, ..., State_s]``
|
|
1073
1073
|
|
|
1074
|
+
remove_first_state: bool, default True
|
|
1075
|
+
If True, the first state will be removed from the model. This is done because there are only n-1 degrees of
|
|
1076
|
+
freedom in the seasonal component, and one state is not identified. If False, the first state will be
|
|
1077
|
+
included in the model, but it will not be identified -- you will need to handle this in the priors (e.g. with
|
|
1078
|
+
ZeroSumNormal).
|
|
1079
|
+
|
|
1074
1080
|
Notes
|
|
1075
1081
|
-----
|
|
1076
1082
|
A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
|
|
@@ -1163,7 +1169,7 @@ class TimeSeasonality(Component):
|
|
|
1163
1169
|
innovations: bool = True,
|
|
1164
1170
|
name: str | None = None,
|
|
1165
1171
|
state_names: list | None = None,
|
|
1166
|
-
|
|
1172
|
+
remove_first_state: bool = True,
|
|
1167
1173
|
):
|
|
1168
1174
|
if name is None:
|
|
1169
1175
|
name = f"Seasonal[s={season_length}]"
|
|
@@ -1176,14 +1182,15 @@ class TimeSeasonality(Component):
|
|
|
1176
1182
|
)
|
|
1177
1183
|
state_names = state_names.copy()
|
|
1178
1184
|
self.innovations = innovations
|
|
1179
|
-
self.
|
|
1185
|
+
self.remove_first_state = remove_first_state
|
|
1180
1186
|
|
|
1181
|
-
if self.
|
|
1187
|
+
if self.remove_first_state:
|
|
1182
1188
|
# In traditional models, the first state isn't identified, so we can help out the user by automatically
|
|
1183
1189
|
# discarding it.
|
|
1184
1190
|
# TODO: Can this be stashed and reconstructed automatically somehow?
|
|
1185
1191
|
state_names.pop(0)
|
|
1186
|
-
|
|
1192
|
+
|
|
1193
|
+
k_states = season_length - int(self.remove_first_state)
|
|
1187
1194
|
|
|
1188
1195
|
super().__init__(
|
|
1189
1196
|
name=name,
|
|
@@ -1218,8 +1225,16 @@ class TimeSeasonality(Component):
|
|
|
1218
1225
|
self.shock_names = [f"{self.name}"]
|
|
1219
1226
|
|
|
1220
1227
|
def make_symbolic_graph(self) -> None:
|
|
1221
|
-
|
|
1222
|
-
|
|
1228
|
+
if self.remove_first_state:
|
|
1229
|
+
# In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
|
|
1230
|
+
# all previous states.
|
|
1231
|
+
T = np.eye(self.k_states, k=-1)
|
|
1232
|
+
T[0, :] = -1
|
|
1233
|
+
else:
|
|
1234
|
+
# In this case we assume the user to be responsible for ensuring the states sum to zero, so T is just a
|
|
1235
|
+
# circulant matrix that cycles between the states.
|
|
1236
|
+
T = np.eye(self.k_states, k=1)
|
|
1237
|
+
T[-1, 0] = 1
|
|
1223
1238
|
|
|
1224
1239
|
self.ssm["transition", :, :] = T
|
|
1225
1240
|
self.ssm["design", 0, 0] = 1
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
from pymc.model.core import Model
|
|
4
|
+
from pymc.model.fgraph import fgraph_from_model
|
|
5
|
+
from pytensor import Variable
|
|
6
|
+
from pytensor.compile import SharedVariable
|
|
7
|
+
from pytensor.graph import Constant, graph_inputs
|
|
8
|
+
from pytensor.graph.basic import equal_computations
|
|
9
|
+
from pytensor.tensor.random.type import RandomType
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def equal_computations_up_to_root(
|
|
13
|
+
xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True
|
|
14
|
+
) -> bool:
|
|
15
|
+
# Check if graphs are equivalent even if root variables have distinct identities
|
|
16
|
+
|
|
17
|
+
x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)]
|
|
18
|
+
y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)]
|
|
19
|
+
if len(x_graph_inputs) != len(y_graph_inputs):
|
|
20
|
+
return False
|
|
21
|
+
for x, y in zip(x_graph_inputs, y_graph_inputs):
|
|
22
|
+
if x.type != y.type:
|
|
23
|
+
return False
|
|
24
|
+
if x.name != y.name:
|
|
25
|
+
return False
|
|
26
|
+
if isinstance(x, SharedVariable):
|
|
27
|
+
# if not isinstance(y, SharedVariable):
|
|
28
|
+
# return False
|
|
29
|
+
if isinstance(x.type, RandomType) and ignore_rng_values:
|
|
30
|
+
continue
|
|
31
|
+
if not x.type.values_eq(x.get_value(), y.get_value()):
|
|
32
|
+
return False
|
|
33
|
+
|
|
34
|
+
return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def equivalent_models(model1: Model, model2: Model) -> bool:
|
|
38
|
+
"""Check whether two PyMC models are equivalent.
|
|
39
|
+
|
|
40
|
+
Examples
|
|
41
|
+
--------
|
|
42
|
+
|
|
43
|
+
.. code-block:: python
|
|
44
|
+
|
|
45
|
+
import pymc as pm
|
|
46
|
+
from pymc_extras.utils.model_equivalence import equivalent_models
|
|
47
|
+
|
|
48
|
+
with pm.Model() as m1:
|
|
49
|
+
x = pm.Normal("x")
|
|
50
|
+
y = pm.Normal("y", x)
|
|
51
|
+
|
|
52
|
+
with pm.Model() as m2:
|
|
53
|
+
x = pm.Normal("x")
|
|
54
|
+
y = pm.Normal("y", x + 1)
|
|
55
|
+
|
|
56
|
+
with pm.Model() as m3:
|
|
57
|
+
x = pm.Normal("x")
|
|
58
|
+
y = pm.Normal("y", x)
|
|
59
|
+
|
|
60
|
+
assert not equivalent_models(m1, m2)
|
|
61
|
+
assert equivalent_models(m1, m3)
|
|
62
|
+
|
|
63
|
+
"""
|
|
64
|
+
fgraph1, _ = fgraph_from_model(model1)
|
|
65
|
+
fgraph2, _ = fgraph_from_model(model2)
|
|
66
|
+
return equal_computations_up_to_root(fgraph1.outputs, fgraph2.outputs)
|
pymc_extras/version.txt
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.2.
|
|
1
|
+
0.2.2
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: pymc-extras
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.2
|
|
4
4
|
Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
|
|
5
5
|
Home-page: http://github.com/pymc-devs/pymc-extras
|
|
6
6
|
Maintainer: PyMC Developers
|
|
@@ -34,6 +34,17 @@ Provides-Extra: dev
|
|
|
34
34
|
Requires-Dist: dask[all]; extra == "dev"
|
|
35
35
|
Requires-Dist: blackjax; extra == "dev"
|
|
36
36
|
Requires-Dist: statsmodels; extra == "dev"
|
|
37
|
+
Dynamic: classifier
|
|
38
|
+
Dynamic: description
|
|
39
|
+
Dynamic: description-content-type
|
|
40
|
+
Dynamic: home-page
|
|
41
|
+
Dynamic: license
|
|
42
|
+
Dynamic: maintainer
|
|
43
|
+
Dynamic: maintainer-email
|
|
44
|
+
Dynamic: provides-extra
|
|
45
|
+
Dynamic: requires-dist
|
|
46
|
+
Dynamic: requires-python
|
|
47
|
+
Dynamic: summary
|
|
37
48
|
|
|
38
49
|
# Welcome to `pymc-extras`
|
|
39
50
|
<a href="https://gitpod.io/#https://github.com/pymc-devs/pymc-extras">
|
|
@@ -63,10 +74,9 @@ import pymc as pm
|
|
|
63
74
|
import pymc_extras as pmx
|
|
64
75
|
|
|
65
76
|
with pm.Model():
|
|
77
|
+
alpha = pmx.ParabolicFractal('alpha', b=1, c=1)
|
|
66
78
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
...
|
|
79
|
+
...
|
|
70
80
|
|
|
71
81
|
```
|
|
72
82
|
|
|
@@ -1,31 +1,34 @@
|
|
|
1
|
-
pymc_extras/__init__.py,sha256=
|
|
1
|
+
pymc_extras/__init__.py,sha256=URh185f6b1xp2Taj2W2NJuW_hErKufBcLeQ0WDCyaNk,1160
|
|
2
2
|
pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
|
|
3
3
|
pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
|
|
4
4
|
pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
|
|
5
5
|
pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
|
|
6
|
-
pymc_extras/version.txt,sha256=
|
|
6
|
+
pymc_extras/version.txt,sha256=mY9riH7Xpu9E6EIQ0CN7cVvtQupyVPSNxBiv7ApIQQM,6
|
|
7
7
|
pymc_extras/distributions/__init__.py,sha256=gTX7tvX8NcgP7V72URV7GeqF1aAEjGVbuW8LMxhXceY,1295
|
|
8
8
|
pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
|
|
9
9
|
pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
|
|
10
10
|
pymc_extras/distributions/histogram_utils.py,sha256=5RTvlGCUrp2qzshrchmPyWxjhs6RIYL62SMikjDM1TU,5814
|
|
11
|
-
pymc_extras/distributions/timeseries.py,sha256=
|
|
11
|
+
pymc_extras/distributions/timeseries.py,sha256=M5MZ-nik_tgkaoZ1hdUGEZ9g04DQyVLwszVJqSKwNcY,12719
|
|
12
12
|
pymc_extras/distributions/multivariate/__init__.py,sha256=E8OeLW9tTotCbrUjEo4um76-_WQD56PehsPzkKmhfyA,93
|
|
13
13
|
pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNPi727uTbvAdxl9fm1zNBqY,16005
|
|
14
14
|
pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
|
|
15
15
|
pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
|
|
16
16
|
pymc_extras/inference/__init__.py,sha256=5cXpaQQnW0mJJ3x8wSxmYu63l--Xab5D_gMtjA6Q3uU,666
|
|
17
17
|
pymc_extras/inference/find_map.py,sha256=T0uO8prUI5aBNuR1AN8fbA4cHmLRQLXznwJrfxfe7CA,15723
|
|
18
|
-
pymc_extras/inference/fit.py,sha256=
|
|
18
|
+
pymc_extras/inference/fit.py,sha256=S9R48dh74s6K0MC9Iys4NAwVjP6rVRfx6SF-kPiR70E,1165
|
|
19
19
|
pymc_extras/inference/laplace.py,sha256=OglOvnxfHLe0VXxBC1-ddVzADR9zgGxUPScM6P6FYo8,21163
|
|
20
|
-
pymc_extras/inference/pathfinder.py,sha256=
|
|
20
|
+
pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
|
|
21
|
+
pymc_extras/inference/pathfinder/importance_sampling.py,sha256=VvmuaE3aw_Mo3tMwswfF0rqe19mnhOCpzIScaJzjA1Y,6159
|
|
22
|
+
pymc_extras/inference/pathfinder/lbfgs.py,sha256=P0UIOVtspdLzDU6alK-y91qzVAzXjYAXPuGmZ1nRqMo,5715
|
|
23
|
+
pymc_extras/inference/pathfinder/pathfinder.py,sha256=fomZ5voVcWxvhWpeIZV7IHGIJCasT1g0ivC4dC3-0GM,63694
|
|
21
24
|
pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
|
|
22
25
|
pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
|
|
23
26
|
pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
24
|
-
pymc_extras/model/model_api.py,sha256=
|
|
27
|
+
pymc_extras/model/model_api.py,sha256=UHMfQXxWBujeSiUySU0fDUC5Sd_BjT8FoVz3iBxQH_4,2400
|
|
25
28
|
pymc_extras/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
26
|
-
pymc_extras/model/marginal/distributions.py,sha256=
|
|
27
|
-
pymc_extras/model/marginal/graph_analysis.py,sha256=
|
|
28
|
-
pymc_extras/model/marginal/marginal_model.py,sha256=
|
|
29
|
+
pymc_extras/model/marginal/distributions.py,sha256=iM1yT7_BmivgUSloQPKE2QXGPgjvLqDMY_OTBGsdAWg,15563
|
|
30
|
+
pymc_extras/model/marginal/graph_analysis.py,sha256=0hWUH_PjfpgneQ3NaT__pWHS1fh50zNbI86kH4Nub0E,15693
|
|
31
|
+
pymc_extras/model/marginal/marginal_model.py,sha256=oNsiSWHjOPCTDxNEivEILLP_cOuBarm29Gr2p6hWHIM,23594
|
|
29
32
|
pymc_extras/model/transforms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
33
|
pymc_extras/model/transforms/autoreparam.py,sha256=_NltGWmNqi_X9sHCqAvWcBveLTPxVy11-wENFTcN6kk,12377
|
|
31
34
|
pymc_extras/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -34,7 +37,7 @@ pymc_extras/statespace/__init__.py,sha256=0MtZj7yT6jcyERvITnn-nkhyY8fO6Za4_vV53C
|
|
|
34
37
|
pymc_extras/statespace/core/__init__.py,sha256=huHEiXAm8zV2MZyZ8GBHp6q7_fnWqveM7lC6ilpb3iE,309
|
|
35
38
|
pymc_extras/statespace/core/compile.py,sha256=1c8Q9D9zeUe7F0z7CH6q1C6ZuLg2_imgk8RoE_KMaFI,1608
|
|
36
39
|
pymc_extras/statespace/core/representation.py,sha256=DwNIun6wdeEA20oWBx5M4govyWTf5JI87aGQ_E6Mb4U,18956
|
|
37
|
-
pymc_extras/statespace/core/statespace.py,sha256=
|
|
40
|
+
pymc_extras/statespace/core/statespace.py,sha256=K_WVnWKlI6sR2kgriq9sctQVvwXCeAirm14TthDpmRM,96860
|
|
38
41
|
pymc_extras/statespace/filters/__init__.py,sha256=N9Q4D0gAq_ZtT-GtrqiX1HkSg6Orv7o1TbrWUtnbTJE,420
|
|
39
42
|
pymc_extras/statespace/filters/distributions.py,sha256=-9j__vRqL5hKyYFnQr5HKHA5kEFzwiuSccH4mslTOuQ,12900
|
|
40
43
|
pymc_extras/statespace/filters/kalman_filter.py,sha256=HELC3aK4k8EdWlUAk5_F7y7YkIz-Xi_0j2AwRgAXgcc,31949
|
|
@@ -44,7 +47,7 @@ pymc_extras/statespace/models/ETS.py,sha256=o039M-6aCxyMXbbKvUeNVZhheCKvvNIAmuj0
|
|
|
44
47
|
pymc_extras/statespace/models/SARIMAX.py,sha256=SX0eiSK1pOt4dHBjWzBqVpRz67pBGLN5pQQgXcOiOgY,21607
|
|
45
48
|
pymc_extras/statespace/models/VARMAX.py,sha256=xkIuftNc_5NHFpqZalExni99-1kovnzm5OjMIDNgaxY,15989
|
|
46
49
|
pymc_extras/statespace/models/__init__.py,sha256=U79b8rTHBNijVvvGOd43nLu4PCloPUH1rwlN87-n88c,317
|
|
47
|
-
pymc_extras/statespace/models/structural.py,sha256=
|
|
50
|
+
pymc_extras/statespace/models/structural.py,sha256=sep9pesJdRN4X8Bea6_RhO3112uWOZRuYRxO6ibl_OA,63943
|
|
48
51
|
pymc_extras/statespace/models/utilities.py,sha256=G9GuHKsghmIYOlfkPtvxBWF-FZY5-5JI1fJQM8N7EnE,15373
|
|
49
52
|
pymc_extras/statespace/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
50
53
|
pymc_extras/statespace/utils/constants.py,sha256=Kf6j75ABaDQeRODxKQ76wTUQV4F5sTjn1KBcZgCQx20,2403
|
|
@@ -52,6 +55,7 @@ pymc_extras/statespace/utils/coord_tools.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm
|
|
|
52
55
|
pymc_extras/statespace/utils/data_tools.py,sha256=caanvrxDu9g-dEKff2bbmaVTs6-71kkSoYIiiSUXhw4,5985
|
|
53
56
|
pymc_extras/utils/__init__.py,sha256=yxI9cJ7fCtVQS0GFw0y6mDGZIQZiK53vm3UNKqIuGSk,758
|
|
54
57
|
pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYkA,10056
|
|
58
|
+
pymc_extras/utils/model_equivalence.py,sha256=8QIftID2HDxD659i0RXHazQ-l2Q5YegCRLcDqb2p9Pc,2187
|
|
55
59
|
pymc_extras/utils/pivoted_cholesky.py,sha256=QtnjP0pAl9b77fLAu-semwT4_9dcoiqx3dz1xKGBjMk,1871
|
|
56
60
|
pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,6841
|
|
57
61
|
pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
|
|
@@ -62,23 +66,23 @@ tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6
|
|
|
62
66
|
tests/test_laplace.py,sha256=5ioEyP6AzmMszrtQRz0KWTsCCU35SEhSOdBcYfYzptE,8228
|
|
63
67
|
tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
|
|
64
68
|
tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
|
|
65
|
-
tests/test_pathfinder.py,sha256=
|
|
66
|
-
tests/test_pivoted_cholesky.py,sha256=
|
|
69
|
+
tests/test_pathfinder.py,sha256=GnSbZJ9QuFW9UVbkWaVgMVqQZTCttOyz_rSflxhQ-EA,4955
|
|
70
|
+
tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
|
|
67
71
|
tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
|
|
68
72
|
tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
|
|
69
73
|
tests/test_splines.py,sha256=xSZi4hqqReN1H8LHr0xjDmpomhDQm8auIsWQjFOyjbM,2608
|
|
70
|
-
tests/utils.py,sha256=
|
|
74
|
+
tests/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
71
75
|
tests/distributions/__init__.py,sha256=jt-oloszTLNFwi9AgU3M4m6xKQ8xpQE338rmmaMZcMs,795
|
|
72
76
|
tests/distributions/test_continuous.py,sha256=1-bu-IP6RgLUJnuPYpOD8ZS1ahYbKtsJ9oflBfqCaFo,5477
|
|
73
77
|
tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcleoCXUi0,7776
|
|
74
78
|
tests/distributions/test_discrete_markov_chain.py,sha256=8RCHZXSB8IWjniuKaGGlM_iTWGmdrcOqginxmrAeEJg,9212
|
|
75
79
|
tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIXeRBKiEn6kjs,10452
|
|
76
80
|
tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
77
|
-
tests/model/test_model_api.py,sha256=
|
|
81
|
+
tests/model/test_model_api.py,sha256=FJvMTmexovRELZOUcUyk-6Vwk9qSiH7hIFoiArgl5mk,1040
|
|
78
82
|
tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
79
|
-
tests/model/marginal/test_distributions.py,sha256=
|
|
83
|
+
tests/model/marginal/test_distributions.py,sha256=p5f73g4ogxYkdZaBndZV_1ra8TCppXiRlUpaaTwEe-M,5195
|
|
80
84
|
tests/model/marginal/test_graph_analysis.py,sha256=raoj41NusMOj1zzPCrxrlQODqX6Ey8Ft_o32pNTe5qg,6712
|
|
81
|
-
tests/model/marginal/test_marginal_model.py,sha256=
|
|
85
|
+
tests/model/marginal/test_marginal_model.py,sha256=uOmARalkdWq3sDbnJQ0KjiLwviqauZOAnafmYS_Cnd8,35475
|
|
82
86
|
tests/statespace/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
83
87
|
tests/statespace/test_ETS.py,sha256=IPg3uQ7xEGqDMEHu993vtUTV7r-uNAxmw23sr5MVGfQ,15582
|
|
84
88
|
tests/statespace/test_SARIMAX.py,sha256=1BYNOm9aSHnpn-qbpe3YsQVH8m-mXcp_gvKgWhWn1W4,12948
|
|
@@ -87,15 +91,15 @@ tests/statespace/test_coord_assignment.py,sha256=2GBm46-0eI4QNh4bvp3D7az58stcA5Z
|
|
|
87
91
|
tests/statespace/test_distributions.py,sha256=WQ_ROyd-PL3cimXTyEtyVaMEVtS7Hue2Z0lN7UnGDyo,9122
|
|
88
92
|
tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_YvQ2lM6hiQ,11600
|
|
89
93
|
tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
|
|
90
|
-
tests/statespace/test_statespace.py,sha256=
|
|
94
|
+
tests/statespace/test_statespace.py,sha256=JoupFFpG8PmpB_NFV471IuTmyXhEd6_vOISwVCRrBBM,30570
|
|
91
95
|
tests/statespace/test_statespace_JAX.py,sha256=hZOc6xxYdVeATPCKmcHMLOVcuvdzGRzgQQ4RrDenwk8,5279
|
|
92
|
-
tests/statespace/test_structural.py,sha256=
|
|
96
|
+
tests/statespace/test_structural.py,sha256=HD8OaGbjuH4y3xv_uG-R1xLZpPpcb4-3dbcTeb_imLY,29306
|
|
93
97
|
tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
94
98
|
tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
|
|
95
99
|
tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
|
|
96
100
|
tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
|
|
97
|
-
pymc_extras-0.2.
|
|
98
|
-
pymc_extras-0.2.
|
|
99
|
-
pymc_extras-0.2.
|
|
100
|
-
pymc_extras-0.2.
|
|
101
|
-
pymc_extras-0.2.
|
|
101
|
+
pymc_extras-0.2.2.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
|
|
102
|
+
pymc_extras-0.2.2.dist-info/METADATA,sha256=9k60kKKNzr7E24gACpTBNOWj-tRTNOiujSZOFD89G5c,5140
|
|
103
|
+
pymc_extras-0.2.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
104
|
+
pymc_extras-0.2.2.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
|
|
105
|
+
pymc_extras-0.2.2.dist-info/RECORD,,
|
|
@@ -6,7 +6,7 @@ from pymc.logprob.abstract import _logprob
|
|
|
6
6
|
from pytensor import tensor as pt
|
|
7
7
|
from scipy.stats import norm
|
|
8
8
|
|
|
9
|
-
from pymc_extras import
|
|
9
|
+
from pymc_extras import marginalize
|
|
10
10
|
from pymc_extras.distributions import DiscreteMarkovChain
|
|
11
11
|
from pymc_extras.model.marginal.distributions import MarginalFiniteDiscreteRV
|
|
12
12
|
|
|
@@ -21,6 +21,7 @@ def test_marginalized_bernoulli_logp():
|
|
|
21
21
|
[mu],
|
|
22
22
|
[idx, y],
|
|
23
23
|
dims_connections=(((),),),
|
|
24
|
+
dims=(),
|
|
24
25
|
)(mu)[0].owner
|
|
25
26
|
|
|
26
27
|
y_vv = y.clone()
|
|
@@ -43,7 +44,7 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
|
|
|
43
44
|
if batch_chain and not batch_emission:
|
|
44
45
|
pytest.skip("Redundant implicit combination")
|
|
45
46
|
|
|
46
|
-
with
|
|
47
|
+
with pm.Model() as m:
|
|
47
48
|
P = [[0, 1], [1, 0]]
|
|
48
49
|
init_dist = pm.Categorical.dist(p=[1, 0])
|
|
49
50
|
chain = DiscreteMarkovChain(
|
|
@@ -53,8 +54,8 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
|
|
|
53
54
|
"emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None
|
|
54
55
|
)
|
|
55
56
|
|
|
56
|
-
|
|
57
|
-
logp_fn =
|
|
57
|
+
marginal_m = marginalize(m, [chain])
|
|
58
|
+
logp_fn = marginal_m.compile_logp()
|
|
58
59
|
|
|
59
60
|
test_value = np.array([-1, 1, -1, 1])
|
|
60
61
|
expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval()
|
|
@@ -70,7 +71,7 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
|
|
|
70
71
|
)
|
|
71
72
|
def test_marginalized_hmm_categorical_emission(categorical_emission):
|
|
72
73
|
"""Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""
|
|
73
|
-
with
|
|
74
|
+
with pm.Model() as m:
|
|
74
75
|
P = np.array([[0.5, 0.5], [0.3, 0.7]])
|
|
75
76
|
init_dist = pm.Categorical.dist(p=[0.375, 0.625])
|
|
76
77
|
chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
|
|
@@ -78,11 +79,11 @@ def test_marginalized_hmm_categorical_emission(categorical_emission):
|
|
|
78
79
|
emission = pm.Categorical("emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain])
|
|
79
80
|
else:
|
|
80
81
|
emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
|
|
81
|
-
|
|
82
|
+
marginal_m = marginalize(m, [chain])
|
|
82
83
|
|
|
83
84
|
test_value = np.array([0, 0, 1])
|
|
84
85
|
expected_logp = np.log(0.1344) # Shown at the 10m22s mark in the video
|
|
85
|
-
logp_fn =
|
|
86
|
+
logp_fn = marginal_m.compile_logp()
|
|
86
87
|
np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
|
|
87
88
|
|
|
88
89
|
|
|
@@ -95,7 +96,7 @@ def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch
|
|
|
95
96
|
(2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape))
|
|
96
97
|
)
|
|
97
98
|
emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape
|
|
98
|
-
with
|
|
99
|
+
with pm.Model() as m:
|
|
99
100
|
P = [[0, 1], [1, 0]]
|
|
100
101
|
init_dist = pm.Categorical.dist(p=[1, 0])
|
|
101
102
|
chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape)
|
|
@@ -108,10 +109,10 @@ def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch
|
|
|
108
109
|
emission2_mu = emission2_mu[..., None]
|
|
109
110
|
emission_2 = pm.Normal("emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape)
|
|
110
111
|
|
|
111
|
-
|
|
112
|
-
m.marginalize([chain])
|
|
112
|
+
marginal_m = marginalize(m, [chain])
|
|
113
113
|
|
|
114
|
-
|
|
114
|
+
with pytest.warns(UserWarning, match="multiple dependent variables"):
|
|
115
|
+
logp_fn = marginal_m.compile_logp(sum=False)
|
|
115
116
|
|
|
116
117
|
test_value = np.array([-1, 1, -1, 1])
|
|
117
118
|
multiplier = 2 + batch_emission1 + batch_emission2
|