pymc-extras 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/inference/fit.py +0 -4
- pymc_extras/inference/pathfinder/__init__.py +3 -0
- pymc_extras/inference/pathfinder/importance_sampling.py +139 -0
- pymc_extras/inference/pathfinder/lbfgs.py +190 -0
- pymc_extras/inference/pathfinder/pathfinder.py +1746 -0
- pymc_extras/model/model_api.py +18 -2
- pymc_extras/statespace/core/statespace.py +79 -36
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.2.dist-info}/METADATA +13 -2
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.2.dist-info}/RECORD +16 -13
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.2.dist-info}/WHEEL +1 -1
- tests/model/test_model_api.py +9 -0
- tests/statespace/test_statespace.py +54 -0
- tests/test_pathfinder.py +135 -7
- pymc_extras/inference/pathfinder.py +0 -134
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.2.dist-info}/LICENSE +0 -0
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.2.dist-info}/top_level.txt +0 -0
pymc_extras/model/model_api.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
from functools import wraps
|
|
2
|
+
from inspect import signature
|
|
2
3
|
|
|
3
|
-
|
|
4
|
+
import pytensor.tensor as pt
|
|
5
|
+
|
|
6
|
+
from pymc import Data, Model
|
|
4
7
|
|
|
5
8
|
|
|
6
9
|
def as_model(*model_args, **model_kwargs):
|
|
@@ -9,6 +12,8 @@ def as_model(*model_args, **model_kwargs):
|
|
|
9
12
|
This removes all need to think about context managers and lets you separate creating a generative model from using the model.
|
|
10
13
|
Additionally, a coords argument is added to the function so coords can be changed during function invocation
|
|
11
14
|
|
|
15
|
+
All parameters are wrapped with a `pm.Data` object if the underlying type of the data supports it.
|
|
16
|
+
|
|
12
17
|
Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
|
|
13
18
|
|
|
14
19
|
Examples
|
|
@@ -47,8 +52,19 @@ def as_model(*model_args, **model_kwargs):
|
|
|
47
52
|
@wraps(f)
|
|
48
53
|
def make_model(*args, **kwargs):
|
|
49
54
|
coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
|
|
55
|
+
sig = signature(f)
|
|
56
|
+
ba = sig.bind(*args, **kwargs)
|
|
57
|
+
ba.apply_defaults()
|
|
58
|
+
|
|
50
59
|
with Model(*model_args, coords=coords, **model_kwargs) as m:
|
|
51
|
-
|
|
60
|
+
for name, v in ba.arguments.items():
|
|
61
|
+
# Only wrap pm.Data around values pytensor can process
|
|
62
|
+
try:
|
|
63
|
+
_ = pt.as_tensor_variable(v)
|
|
64
|
+
ba.arguments[name] = Data(name, v)
|
|
65
|
+
except (NotImplementedError, TypeError, ValueError):
|
|
66
|
+
pass
|
|
67
|
+
f(*ba.args, **ba.kwargs)
|
|
52
68
|
return m
|
|
53
69
|
|
|
54
70
|
return make_model
|
|
@@ -15,6 +15,9 @@ from pymc.model.transform.optimization import freeze_dims_and_data
|
|
|
15
15
|
from pymc.util import RandomState
|
|
16
16
|
from pytensor import Variable, graph_replace
|
|
17
17
|
from pytensor.compile import get_mode
|
|
18
|
+
from rich.box import SIMPLE_HEAD
|
|
19
|
+
from rich.console import Console
|
|
20
|
+
from rich.table import Table
|
|
18
21
|
|
|
19
22
|
from pymc_extras.statespace.core.representation import PytensorRepresentation
|
|
20
23
|
from pymc_extras.statespace.filters import (
|
|
@@ -254,53 +257,72 @@ class PyMCStateSpace:
|
|
|
254
257
|
self.kalman_smoother = KalmanSmoother()
|
|
255
258
|
self.make_symbolic_graph()
|
|
256
259
|
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
def _print_prior_requirements(self) -> None:
|
|
270
|
-
"""
|
|
271
|
-
Prints a short report to the terminal about the priors needed for the model, including their names,
|
|
260
|
+
self.requirement_table = None
|
|
261
|
+
self._populate_prior_requirements()
|
|
262
|
+
self._populate_data_requirements()
|
|
263
|
+
|
|
264
|
+
if verbose and self.requirement_table:
|
|
265
|
+
console = Console()
|
|
266
|
+
console.print(self.requirement_table)
|
|
267
|
+
|
|
268
|
+
def _populate_prior_requirements(self) -> None:
|
|
269
|
+
"""
|
|
270
|
+
Add requirements about priors needed for the model to a rich table, including their names,
|
|
272
271
|
shapes, named dimensions, and any parameter constraints.
|
|
273
272
|
"""
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
273
|
+
# Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
|
|
274
|
+
# is not true.
|
|
275
|
+
try:
|
|
276
|
+
if not isinstance(self.param_info, dict):
|
|
277
|
+
return
|
|
278
|
+
except NotImplementedError:
|
|
279
|
+
return
|
|
278
280
|
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
281
|
+
if self.requirement_table is None:
|
|
282
|
+
self._initialize_requirement_table()
|
|
283
|
+
|
|
284
|
+
for param, info in self.param_info.items():
|
|
285
|
+
self.requirement_table.add_row(
|
|
286
|
+
param, str(info["shape"]), info["constraints"], str(info["dims"])
|
|
287
|
+
)
|
|
284
288
|
|
|
285
|
-
def
|
|
289
|
+
def _populate_data_requirements(self) -> None:
|
|
286
290
|
"""
|
|
287
|
-
|
|
288
|
-
and named dimensions.
|
|
291
|
+
Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
|
|
289
292
|
"""
|
|
290
|
-
|
|
293
|
+
try:
|
|
294
|
+
if not isinstance(self.data_info, dict):
|
|
295
|
+
return
|
|
296
|
+
except NotImplementedError:
|
|
291
297
|
return
|
|
292
298
|
|
|
293
|
-
|
|
299
|
+
if self.requirement_table is None:
|
|
300
|
+
self._initialize_requirement_table()
|
|
301
|
+
else:
|
|
302
|
+
self.requirement_table.add_section()
|
|
303
|
+
|
|
294
304
|
for data, info in self.data_info.items():
|
|
295
|
-
|
|
296
|
-
|
|
305
|
+
self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"]))
|
|
306
|
+
|
|
307
|
+
def _initialize_requirement_table(self) -> None:
|
|
308
|
+
self.requirement_table = Table(
|
|
309
|
+
show_header=True,
|
|
310
|
+
show_edge=True,
|
|
311
|
+
box=SIMPLE_HEAD,
|
|
312
|
+
highlight=True,
|
|
313
|
+
)
|
|
297
314
|
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
315
|
+
self.requirement_table.title = "Model Requirements"
|
|
316
|
+
self.requirement_table.caption = (
|
|
317
|
+
"These parameters should be assigned priors inside a PyMC model block before "
|
|
318
|
+
"calling the build_statespace_graph method."
|
|
302
319
|
)
|
|
303
320
|
|
|
321
|
+
self.requirement_table.add_column("Variable", justify="left")
|
|
322
|
+
self.requirement_table.add_column("Shape", justify="left")
|
|
323
|
+
self.requirement_table.add_column("Constraints", justify="left")
|
|
324
|
+
self.requirement_table.add_column("Dimensions", justify="right")
|
|
325
|
+
|
|
304
326
|
def _unpack_statespace_with_placeholders(
|
|
305
327
|
self,
|
|
306
328
|
) -> tuple[
|
|
@@ -961,10 +983,31 @@ class PyMCStateSpace:
|
|
|
961
983
|
list[pm.Flat]
|
|
962
984
|
A list of pm.Flat variables representing all parameters estimated by the model.
|
|
963
985
|
"""
|
|
986
|
+
|
|
987
|
+
def infer_variable_shape(name):
|
|
988
|
+
shape = self._name_to_variable[name].type.shape
|
|
989
|
+
if not any(dim is None for dim in shape):
|
|
990
|
+
return shape
|
|
991
|
+
|
|
992
|
+
dim_names = self._fit_dims.get(name, None)
|
|
993
|
+
if dim_names is None:
|
|
994
|
+
raise ValueError(
|
|
995
|
+
f"Could not infer shape for {name}, because it was not given coords during model"
|
|
996
|
+
f"fitting"
|
|
997
|
+
)
|
|
998
|
+
|
|
999
|
+
shape_from_coords = tuple([len(self._fit_coords[dim]) for dim in dim_names])
|
|
1000
|
+
return tuple(
|
|
1001
|
+
[
|
|
1002
|
+
shape[i] if shape[i] is not None else shape_from_coords[i]
|
|
1003
|
+
for i in range(len(shape))
|
|
1004
|
+
]
|
|
1005
|
+
)
|
|
1006
|
+
|
|
964
1007
|
for name in self.param_names:
|
|
965
1008
|
pm.Flat(
|
|
966
1009
|
name,
|
|
967
|
-
shape=
|
|
1010
|
+
shape=infer_variable_shape(name),
|
|
968
1011
|
dims=self._fit_dims.get(name, None),
|
|
969
1012
|
)
|
|
970
1013
|
|
pymc_extras/version.txt
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.2.
|
|
1
|
+
0.2.2
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: pymc-extras
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.2
|
|
4
4
|
Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
|
|
5
5
|
Home-page: http://github.com/pymc-devs/pymc-extras
|
|
6
6
|
Maintainer: PyMC Developers
|
|
@@ -34,6 +34,17 @@ Provides-Extra: dev
|
|
|
34
34
|
Requires-Dist: dask[all]; extra == "dev"
|
|
35
35
|
Requires-Dist: blackjax; extra == "dev"
|
|
36
36
|
Requires-Dist: statsmodels; extra == "dev"
|
|
37
|
+
Dynamic: classifier
|
|
38
|
+
Dynamic: description
|
|
39
|
+
Dynamic: description-content-type
|
|
40
|
+
Dynamic: home-page
|
|
41
|
+
Dynamic: license
|
|
42
|
+
Dynamic: maintainer
|
|
43
|
+
Dynamic: maintainer-email
|
|
44
|
+
Dynamic: provides-extra
|
|
45
|
+
Dynamic: requires-dist
|
|
46
|
+
Dynamic: requires-python
|
|
47
|
+
Dynamic: summary
|
|
37
48
|
|
|
38
49
|
# Welcome to `pymc-extras`
|
|
39
50
|
<a href="https://gitpod.io/#https://github.com/pymc-devs/pymc-extras">
|
|
@@ -3,7 +3,7 @@ pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,39
|
|
|
3
3
|
pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
|
|
4
4
|
pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
|
|
5
5
|
pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
|
|
6
|
-
pymc_extras/version.txt,sha256=
|
|
6
|
+
pymc_extras/version.txt,sha256=mY9riH7Xpu9E6EIQ0CN7cVvtQupyVPSNxBiv7ApIQQM,6
|
|
7
7
|
pymc_extras/distributions/__init__.py,sha256=gTX7tvX8NcgP7V72URV7GeqF1aAEjGVbuW8LMxhXceY,1295
|
|
8
8
|
pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
|
|
9
9
|
pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
|
|
@@ -15,13 +15,16 @@ pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,74
|
|
|
15
15
|
pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
|
|
16
16
|
pymc_extras/inference/__init__.py,sha256=5cXpaQQnW0mJJ3x8wSxmYu63l--Xab5D_gMtjA6Q3uU,666
|
|
17
17
|
pymc_extras/inference/find_map.py,sha256=T0uO8prUI5aBNuR1AN8fbA4cHmLRQLXznwJrfxfe7CA,15723
|
|
18
|
-
pymc_extras/inference/fit.py,sha256=
|
|
18
|
+
pymc_extras/inference/fit.py,sha256=S9R48dh74s6K0MC9Iys4NAwVjP6rVRfx6SF-kPiR70E,1165
|
|
19
19
|
pymc_extras/inference/laplace.py,sha256=OglOvnxfHLe0VXxBC1-ddVzADR9zgGxUPScM6P6FYo8,21163
|
|
20
|
-
pymc_extras/inference/pathfinder.py,sha256=
|
|
20
|
+
pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
|
|
21
|
+
pymc_extras/inference/pathfinder/importance_sampling.py,sha256=VvmuaE3aw_Mo3tMwswfF0rqe19mnhOCpzIScaJzjA1Y,6159
|
|
22
|
+
pymc_extras/inference/pathfinder/lbfgs.py,sha256=P0UIOVtspdLzDU6alK-y91qzVAzXjYAXPuGmZ1nRqMo,5715
|
|
23
|
+
pymc_extras/inference/pathfinder/pathfinder.py,sha256=fomZ5voVcWxvhWpeIZV7IHGIJCasT1g0ivC4dC3-0GM,63694
|
|
21
24
|
pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
|
|
22
25
|
pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
|
|
23
26
|
pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
24
|
-
pymc_extras/model/model_api.py,sha256=
|
|
27
|
+
pymc_extras/model/model_api.py,sha256=UHMfQXxWBujeSiUySU0fDUC5Sd_BjT8FoVz3iBxQH_4,2400
|
|
25
28
|
pymc_extras/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
26
29
|
pymc_extras/model/marginal/distributions.py,sha256=iM1yT7_BmivgUSloQPKE2QXGPgjvLqDMY_OTBGsdAWg,15563
|
|
27
30
|
pymc_extras/model/marginal/graph_analysis.py,sha256=0hWUH_PjfpgneQ3NaT__pWHS1fh50zNbI86kH4Nub0E,15693
|
|
@@ -34,7 +37,7 @@ pymc_extras/statespace/__init__.py,sha256=0MtZj7yT6jcyERvITnn-nkhyY8fO6Za4_vV53C
|
|
|
34
37
|
pymc_extras/statespace/core/__init__.py,sha256=huHEiXAm8zV2MZyZ8GBHp6q7_fnWqveM7lC6ilpb3iE,309
|
|
35
38
|
pymc_extras/statespace/core/compile.py,sha256=1c8Q9D9zeUe7F0z7CH6q1C6ZuLg2_imgk8RoE_KMaFI,1608
|
|
36
39
|
pymc_extras/statespace/core/representation.py,sha256=DwNIun6wdeEA20oWBx5M4govyWTf5JI87aGQ_E6Mb4U,18956
|
|
37
|
-
pymc_extras/statespace/core/statespace.py,sha256=
|
|
40
|
+
pymc_extras/statespace/core/statespace.py,sha256=K_WVnWKlI6sR2kgriq9sctQVvwXCeAirm14TthDpmRM,96860
|
|
38
41
|
pymc_extras/statespace/filters/__init__.py,sha256=N9Q4D0gAq_ZtT-GtrqiX1HkSg6Orv7o1TbrWUtnbTJE,420
|
|
39
42
|
pymc_extras/statespace/filters/distributions.py,sha256=-9j__vRqL5hKyYFnQr5HKHA5kEFzwiuSccH4mslTOuQ,12900
|
|
40
43
|
pymc_extras/statespace/filters/kalman_filter.py,sha256=HELC3aK4k8EdWlUAk5_F7y7YkIz-Xi_0j2AwRgAXgcc,31949
|
|
@@ -63,7 +66,7 @@ tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6
|
|
|
63
66
|
tests/test_laplace.py,sha256=5ioEyP6AzmMszrtQRz0KWTsCCU35SEhSOdBcYfYzptE,8228
|
|
64
67
|
tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
|
|
65
68
|
tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
|
|
66
|
-
tests/test_pathfinder.py,sha256=
|
|
69
|
+
tests/test_pathfinder.py,sha256=GnSbZJ9QuFW9UVbkWaVgMVqQZTCttOyz_rSflxhQ-EA,4955
|
|
67
70
|
tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
|
|
68
71
|
tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
|
|
69
72
|
tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
|
|
@@ -75,7 +78,7 @@ tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcl
|
|
|
75
78
|
tests/distributions/test_discrete_markov_chain.py,sha256=8RCHZXSB8IWjniuKaGGlM_iTWGmdrcOqginxmrAeEJg,9212
|
|
76
79
|
tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIXeRBKiEn6kjs,10452
|
|
77
80
|
tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
78
|
-
tests/model/test_model_api.py,sha256=
|
|
81
|
+
tests/model/test_model_api.py,sha256=FJvMTmexovRELZOUcUyk-6Vwk9qSiH7hIFoiArgl5mk,1040
|
|
79
82
|
tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
80
83
|
tests/model/marginal/test_distributions.py,sha256=p5f73g4ogxYkdZaBndZV_1ra8TCppXiRlUpaaTwEe-M,5195
|
|
81
84
|
tests/model/marginal/test_graph_analysis.py,sha256=raoj41NusMOj1zzPCrxrlQODqX6Ey8Ft_o32pNTe5qg,6712
|
|
@@ -88,15 +91,15 @@ tests/statespace/test_coord_assignment.py,sha256=2GBm46-0eI4QNh4bvp3D7az58stcA5Z
|
|
|
88
91
|
tests/statespace/test_distributions.py,sha256=WQ_ROyd-PL3cimXTyEtyVaMEVtS7Hue2Z0lN7UnGDyo,9122
|
|
89
92
|
tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_YvQ2lM6hiQ,11600
|
|
90
93
|
tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
|
|
91
|
-
tests/statespace/test_statespace.py,sha256=
|
|
94
|
+
tests/statespace/test_statespace.py,sha256=JoupFFpG8PmpB_NFV471IuTmyXhEd6_vOISwVCRrBBM,30570
|
|
92
95
|
tests/statespace/test_statespace_JAX.py,sha256=hZOc6xxYdVeATPCKmcHMLOVcuvdzGRzgQQ4RrDenwk8,5279
|
|
93
96
|
tests/statespace/test_structural.py,sha256=HD8OaGbjuH4y3xv_uG-R1xLZpPpcb4-3dbcTeb_imLY,29306
|
|
94
97
|
tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
95
98
|
tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
|
|
96
99
|
tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
|
|
97
100
|
tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
|
|
98
|
-
pymc_extras-0.2.
|
|
99
|
-
pymc_extras-0.2.
|
|
100
|
-
pymc_extras-0.2.
|
|
101
|
-
pymc_extras-0.2.
|
|
102
|
-
pymc_extras-0.2.
|
|
101
|
+
pymc_extras-0.2.2.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
|
|
102
|
+
pymc_extras-0.2.2.dist-info/METADATA,sha256=9k60kKKNzr7E24gACpTBNOWj-tRTNOiujSZOFD89G5c,5140
|
|
103
|
+
pymc_extras-0.2.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
104
|
+
pymc_extras-0.2.2.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
|
|
105
|
+
pymc_extras-0.2.2.dist-info/RECORD,,
|
tests/model/test_model_api.py
CHANGED
|
@@ -25,5 +25,14 @@ def test_logp():
|
|
|
25
25
|
|
|
26
26
|
mw2 = model_wrapped2(coords=coords)
|
|
27
27
|
|
|
28
|
+
@pmx.as_model()
|
|
29
|
+
def model_wrapped3(mu):
|
|
30
|
+
pm.Normal("x", mu, 1.0, dims="obs")
|
|
31
|
+
|
|
32
|
+
mw3 = model_wrapped3(0.0, coords=coords)
|
|
33
|
+
mw4 = model_wrapped3(np.array([np.nan]), coords=coords)
|
|
34
|
+
|
|
28
35
|
np.testing.assert_equal(model.point_logps(), mw.point_logps())
|
|
29
36
|
np.testing.assert_equal(mw.point_logps(), mw2.point_logps())
|
|
37
|
+
assert mw3["mu"] in mw3.data_vars
|
|
38
|
+
assert "mu" not in mw4
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
1
2
|
from functools import partial
|
|
2
3
|
|
|
3
4
|
import numpy as np
|
|
@@ -349,6 +350,59 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng):
|
|
|
349
350
|
assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values))
|
|
350
351
|
|
|
351
352
|
|
|
353
|
+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
|
|
354
|
+
def test_sample_conditional_with_time_varying():
|
|
355
|
+
class TVCovariance(PyMCStateSpace):
|
|
356
|
+
def __init__(self):
|
|
357
|
+
super().__init__(k_states=1, k_endog=1, k_posdef=1)
|
|
358
|
+
|
|
359
|
+
def make_symbolic_graph(self) -> None:
|
|
360
|
+
self.ssm["transition", 0, 0] = 1.0
|
|
361
|
+
|
|
362
|
+
self.ssm["design", 0, 0] = 1.0
|
|
363
|
+
|
|
364
|
+
sigma_cov = self.make_and_register_variable("sigma_cov", (None,))
|
|
365
|
+
self.ssm["state_cov"] = sigma_cov[:, None, None] ** 2
|
|
366
|
+
|
|
367
|
+
@property
|
|
368
|
+
def param_names(self) -> list[str]:
|
|
369
|
+
return ["sigma_cov"]
|
|
370
|
+
|
|
371
|
+
@property
|
|
372
|
+
def coords(self) -> dict[str, Sequence[str]]:
|
|
373
|
+
return make_default_coords(self)
|
|
374
|
+
|
|
375
|
+
@property
|
|
376
|
+
def state_names(self) -> list[str]:
|
|
377
|
+
return ["level"]
|
|
378
|
+
|
|
379
|
+
@property
|
|
380
|
+
def observed_states(self) -> list[str]:
|
|
381
|
+
return ["level"]
|
|
382
|
+
|
|
383
|
+
@property
|
|
384
|
+
def shock_names(self) -> list[str]:
|
|
385
|
+
return ["level"]
|
|
386
|
+
|
|
387
|
+
ss_mod = TVCovariance()
|
|
388
|
+
empty_data = pd.DataFrame(
|
|
389
|
+
np.nan, index=pd.date_range("2020-01-01", periods=100, freq="D"), columns=["data"]
|
|
390
|
+
)
|
|
391
|
+
|
|
392
|
+
coords = ss_mod.coords
|
|
393
|
+
coords["time"] = empty_data.index
|
|
394
|
+
with pm.Model(coords=coords) as mod:
|
|
395
|
+
log_sigma_cov = pm.Normal("log_sigma_cov", mu=0, sigma=0.1, dims=["time"])
|
|
396
|
+
pm.Deterministic("sigma_cov", pm.math.exp(log_sigma_cov.cumsum()), dims=["time"])
|
|
397
|
+
|
|
398
|
+
ss_mod.build_statespace_graph(data=empty_data)
|
|
399
|
+
|
|
400
|
+
prior = pm.sample_prior_predictive(10)
|
|
401
|
+
|
|
402
|
+
ss_mod.sample_unconditional_prior(prior)
|
|
403
|
+
ss_mod.sample_conditional_prior(prior)
|
|
404
|
+
|
|
405
|
+
|
|
352
406
|
def _make_time_idx(mod, use_datetime_index=True):
|
|
353
407
|
if use_datetime_index:
|
|
354
408
|
mod._fit_coords["time"] = nile.index
|
tests/test_pathfinder.py
CHANGED
|
@@ -18,12 +18,12 @@ import numpy as np
|
|
|
18
18
|
import pymc as pm
|
|
19
19
|
import pytest
|
|
20
20
|
|
|
21
|
+
pytestmark = pytest.mark.filterwarnings("ignore:compile_pymc was renamed to compile:FutureWarning")
|
|
22
|
+
|
|
21
23
|
import pymc_extras as pmx
|
|
22
24
|
|
|
23
25
|
|
|
24
|
-
|
|
25
|
-
def test_pathfinder():
|
|
26
|
-
# Data of the Eight Schools Model
|
|
26
|
+
def eight_schools_model() -> pm.Model:
|
|
27
27
|
J = 8
|
|
28
28
|
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
|
|
29
29
|
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
|
|
@@ -35,11 +35,139 @@ def test_pathfinder():
|
|
|
35
35
|
theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
|
|
36
36
|
obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y)
|
|
37
37
|
|
|
38
|
-
|
|
38
|
+
return model
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@pytest.fixture
|
|
42
|
+
def reference_idata():
|
|
43
|
+
model = eight_schools_model()
|
|
44
|
+
with model:
|
|
45
|
+
idata = pmx.fit(
|
|
46
|
+
method="pathfinder",
|
|
47
|
+
num_paths=50,
|
|
48
|
+
jitter=10.0,
|
|
49
|
+
random_seed=41,
|
|
50
|
+
inference_backend="pymc",
|
|
51
|
+
)
|
|
52
|
+
return idata
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"])
|
|
56
|
+
def test_pathfinder(inference_backend, reference_idata):
|
|
57
|
+
if inference_backend == "blackjax" and sys.platform == "win32":
|
|
58
|
+
pytest.skip("JAX not supported on windows")
|
|
59
|
+
|
|
60
|
+
if inference_backend == "blackjax":
|
|
61
|
+
model = eight_schools_model()
|
|
62
|
+
with model:
|
|
63
|
+
idata = pmx.fit(
|
|
64
|
+
method="pathfinder",
|
|
65
|
+
num_paths=50,
|
|
66
|
+
jitter=10.0,
|
|
67
|
+
random_seed=41,
|
|
68
|
+
inference_backend=inference_backend,
|
|
69
|
+
)
|
|
70
|
+
else:
|
|
71
|
+
idata = reference_idata
|
|
72
|
+
np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.6)
|
|
73
|
+
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.5)
|
|
39
74
|
|
|
40
75
|
assert idata.posterior["mu"].shape == (1, 1000)
|
|
41
76
|
assert idata.posterior["tau"].shape == (1, 1000)
|
|
42
77
|
assert idata.posterior["theta"].shape == (1, 1000, 8)
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@pytest.mark.parametrize("concurrent", ["thread", "process"])
|
|
81
|
+
def test_concurrent_results(reference_idata, concurrent):
|
|
82
|
+
model = eight_schools_model()
|
|
83
|
+
with model:
|
|
84
|
+
idata_conc = pmx.fit(
|
|
85
|
+
method="pathfinder",
|
|
86
|
+
num_paths=50,
|
|
87
|
+
jitter=10.0,
|
|
88
|
+
random_seed=41,
|
|
89
|
+
inference_backend="pymc",
|
|
90
|
+
concurrent=concurrent,
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
np.testing.assert_allclose(
|
|
94
|
+
reference_idata.posterior.mu.data.mean(),
|
|
95
|
+
idata_conc.posterior.mu.data.mean(),
|
|
96
|
+
atol=0.4,
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
np.testing.assert_allclose(
|
|
100
|
+
reference_idata.posterior.tau.data.mean(),
|
|
101
|
+
idata_conc.posterior.tau.data.mean(),
|
|
102
|
+
atol=0.4,
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def test_seed(reference_idata):
|
|
107
|
+
model = eight_schools_model()
|
|
108
|
+
with model:
|
|
109
|
+
idata_41 = pmx.fit(
|
|
110
|
+
method="pathfinder",
|
|
111
|
+
num_paths=50,
|
|
112
|
+
jitter=10.0,
|
|
113
|
+
random_seed=41,
|
|
114
|
+
inference_backend="pymc",
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
idata_123 = pmx.fit(
|
|
118
|
+
method="pathfinder",
|
|
119
|
+
num_paths=50,
|
|
120
|
+
jitter=10.0,
|
|
121
|
+
random_seed=123,
|
|
122
|
+
inference_backend="pymc",
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
assert not np.allclose(idata_41.posterior.mu.data.mean(), idata_123.posterior.mu.data.mean())
|
|
126
|
+
|
|
127
|
+
assert np.allclose(idata_41.posterior.mu.data.mean(), idata_41.posterior.mu.data.mean())
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def test_bfgs_sample():
|
|
131
|
+
import pytensor.tensor as pt
|
|
132
|
+
|
|
133
|
+
from pymc_extras.inference.pathfinder.pathfinder import (
|
|
134
|
+
alpha_recover,
|
|
135
|
+
bfgs_sample,
|
|
136
|
+
inverse_hessian_factors,
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
"""test BFGS sampling"""
|
|
140
|
+
Lp1, N = 8, 10
|
|
141
|
+
L = Lp1 - 1
|
|
142
|
+
J = 6
|
|
143
|
+
num_samples = 1000
|
|
144
|
+
|
|
145
|
+
# mock data
|
|
146
|
+
x_data = np.random.randn(Lp1, N)
|
|
147
|
+
g_data = np.random.randn(Lp1, N)
|
|
148
|
+
|
|
149
|
+
# get factors
|
|
150
|
+
x_full = pt.as_tensor(x_data, dtype="float64")
|
|
151
|
+
g_full = pt.as_tensor(g_data, dtype="float64")
|
|
152
|
+
epsilon = 1e-11
|
|
153
|
+
|
|
154
|
+
x = x_full[1:]
|
|
155
|
+
g = g_full[1:]
|
|
156
|
+
alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon)
|
|
157
|
+
beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J)
|
|
158
|
+
|
|
159
|
+
# sample
|
|
160
|
+
phi, logq = bfgs_sample(
|
|
161
|
+
num_samples=num_samples,
|
|
162
|
+
x=x,
|
|
163
|
+
g=g,
|
|
164
|
+
alpha=alpha,
|
|
165
|
+
beta=beta,
|
|
166
|
+
gamma=gamma,
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
# check shapes
|
|
170
|
+
assert beta.eval().shape == (L, N, 2 * J)
|
|
171
|
+
assert gamma.eval().shape == (L, 2 * J, 2 * J)
|
|
172
|
+
assert phi.eval().shape == (L, num_samples, N)
|
|
173
|
+
assert logq.eval().shape == (L, num_samples)
|