pymc-extras 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +2 -0
- pymc_extras/inference/find_map.py +36 -16
- pymc_extras/inference/fit.py +0 -4
- pymc_extras/inference/laplace.py +17 -10
- pymc_extras/inference/pathfinder/__init__.py +3 -0
- pymc_extras/inference/pathfinder/importance_sampling.py +139 -0
- pymc_extras/inference/pathfinder/lbfgs.py +190 -0
- pymc_extras/inference/pathfinder/pathfinder.py +1746 -0
- pymc_extras/model/marginal/marginal_model.py +2 -1
- pymc_extras/model/model_api.py +18 -2
- pymc_extras/statespace/core/compile.py +1 -1
- pymc_extras/statespace/core/statespace.py +79 -36
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.3.dist-info}/METADATA +16 -4
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.3.dist-info}/RECORD +23 -20
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.3.dist-info}/WHEEL +1 -1
- tests/model/test_model_api.py +9 -0
- tests/statespace/test_statespace.py +54 -0
- tests/test_find_map.py +19 -14
- tests/test_laplace.py +42 -15
- tests/test_pathfinder.py +135 -7
- pymc_extras/inference/pathfinder.py +0 -134
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.3.dist-info}/LICENSE +0 -0
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.3.dist-info}/top_level.txt +0 -0
|
@@ -19,7 +19,8 @@ from pymc.model.fgraph import (
|
|
|
19
19
|
model_free_rv,
|
|
20
20
|
model_from_fgraph,
|
|
21
21
|
)
|
|
22
|
-
from pymc.pytensorf import collect_default_updates,
|
|
22
|
+
from pymc.pytensorf import collect_default_updates, constant_fold, toposort_replace
|
|
23
|
+
from pymc.pytensorf import compile as compile_pymc
|
|
23
24
|
from pymc.util import RandomState, _get_seeds_per_chain
|
|
24
25
|
from pytensor import In, Out
|
|
25
26
|
from pytensor.compile import SharedVariable
|
pymc_extras/model/model_api.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
from functools import wraps
|
|
2
|
+
from inspect import signature
|
|
2
3
|
|
|
3
|
-
|
|
4
|
+
import pytensor.tensor as pt
|
|
5
|
+
|
|
6
|
+
from pymc import Data, Model
|
|
4
7
|
|
|
5
8
|
|
|
6
9
|
def as_model(*model_args, **model_kwargs):
|
|
@@ -9,6 +12,8 @@ def as_model(*model_args, **model_kwargs):
|
|
|
9
12
|
This removes all need to think about context managers and lets you separate creating a generative model from using the model.
|
|
10
13
|
Additionally, a coords argument is added to the function so coords can be changed during function invocation
|
|
11
14
|
|
|
15
|
+
All parameters are wrapped with a `pm.Data` object if the underlying type of the data supports it.
|
|
16
|
+
|
|
12
17
|
Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
|
|
13
18
|
|
|
14
19
|
Examples
|
|
@@ -47,8 +52,19 @@ def as_model(*model_args, **model_kwargs):
|
|
|
47
52
|
@wraps(f)
|
|
48
53
|
def make_model(*args, **kwargs):
|
|
49
54
|
coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
|
|
55
|
+
sig = signature(f)
|
|
56
|
+
ba = sig.bind(*args, **kwargs)
|
|
57
|
+
ba.apply_defaults()
|
|
58
|
+
|
|
50
59
|
with Model(*model_args, coords=coords, **model_kwargs) as m:
|
|
51
|
-
|
|
60
|
+
for name, v in ba.arguments.items():
|
|
61
|
+
# Only wrap pm.Data around values pytensor can process
|
|
62
|
+
try:
|
|
63
|
+
_ = pt.as_tensor_variable(v)
|
|
64
|
+
ba.arguments[name] = Data(name, v)
|
|
65
|
+
except (NotImplementedError, TypeError, ValueError):
|
|
66
|
+
pass
|
|
67
|
+
f(*ba.args, **ba.kwargs)
|
|
52
68
|
return m
|
|
53
69
|
|
|
54
70
|
return make_model
|
|
@@ -30,7 +30,7 @@ def compile_statespace(
|
|
|
30
30
|
|
|
31
31
|
inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
|
|
32
32
|
|
|
33
|
-
_f = pm.
|
|
33
|
+
_f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
|
|
34
34
|
|
|
35
35
|
def f(*, draws=1, **params):
|
|
36
36
|
if isinstance(steps, pt.Variable):
|
|
@@ -15,6 +15,9 @@ from pymc.model.transform.optimization import freeze_dims_and_data
|
|
|
15
15
|
from pymc.util import RandomState
|
|
16
16
|
from pytensor import Variable, graph_replace
|
|
17
17
|
from pytensor.compile import get_mode
|
|
18
|
+
from rich.box import SIMPLE_HEAD
|
|
19
|
+
from rich.console import Console
|
|
20
|
+
from rich.table import Table
|
|
18
21
|
|
|
19
22
|
from pymc_extras.statespace.core.representation import PytensorRepresentation
|
|
20
23
|
from pymc_extras.statespace.filters import (
|
|
@@ -254,53 +257,72 @@ class PyMCStateSpace:
|
|
|
254
257
|
self.kalman_smoother = KalmanSmoother()
|
|
255
258
|
self.make_symbolic_graph()
|
|
256
259
|
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
def _print_prior_requirements(self) -> None:
|
|
270
|
-
"""
|
|
271
|
-
Prints a short report to the terminal about the priors needed for the model, including their names,
|
|
260
|
+
self.requirement_table = None
|
|
261
|
+
self._populate_prior_requirements()
|
|
262
|
+
self._populate_data_requirements()
|
|
263
|
+
|
|
264
|
+
if verbose and self.requirement_table:
|
|
265
|
+
console = Console()
|
|
266
|
+
console.print(self.requirement_table)
|
|
267
|
+
|
|
268
|
+
def _populate_prior_requirements(self) -> None:
|
|
269
|
+
"""
|
|
270
|
+
Add requirements about priors needed for the model to a rich table, including their names,
|
|
272
271
|
shapes, named dimensions, and any parameter constraints.
|
|
273
272
|
"""
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
273
|
+
# Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
|
|
274
|
+
# is not true.
|
|
275
|
+
try:
|
|
276
|
+
if not isinstance(self.param_info, dict):
|
|
277
|
+
return
|
|
278
|
+
except NotImplementedError:
|
|
279
|
+
return
|
|
278
280
|
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
281
|
+
if self.requirement_table is None:
|
|
282
|
+
self._initialize_requirement_table()
|
|
283
|
+
|
|
284
|
+
for param, info in self.param_info.items():
|
|
285
|
+
self.requirement_table.add_row(
|
|
286
|
+
param, str(info["shape"]), info["constraints"], str(info["dims"])
|
|
287
|
+
)
|
|
284
288
|
|
|
285
|
-
def
|
|
289
|
+
def _populate_data_requirements(self) -> None:
|
|
286
290
|
"""
|
|
287
|
-
|
|
288
|
-
and named dimensions.
|
|
291
|
+
Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
|
|
289
292
|
"""
|
|
290
|
-
|
|
293
|
+
try:
|
|
294
|
+
if not isinstance(self.data_info, dict):
|
|
295
|
+
return
|
|
296
|
+
except NotImplementedError:
|
|
291
297
|
return
|
|
292
298
|
|
|
293
|
-
|
|
299
|
+
if self.requirement_table is None:
|
|
300
|
+
self._initialize_requirement_table()
|
|
301
|
+
else:
|
|
302
|
+
self.requirement_table.add_section()
|
|
303
|
+
|
|
294
304
|
for data, info in self.data_info.items():
|
|
295
|
-
|
|
296
|
-
|
|
305
|
+
self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"]))
|
|
306
|
+
|
|
307
|
+
def _initialize_requirement_table(self) -> None:
|
|
308
|
+
self.requirement_table = Table(
|
|
309
|
+
show_header=True,
|
|
310
|
+
show_edge=True,
|
|
311
|
+
box=SIMPLE_HEAD,
|
|
312
|
+
highlight=True,
|
|
313
|
+
)
|
|
297
314
|
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
315
|
+
self.requirement_table.title = "Model Requirements"
|
|
316
|
+
self.requirement_table.caption = (
|
|
317
|
+
"These parameters should be assigned priors inside a PyMC model block before "
|
|
318
|
+
"calling the build_statespace_graph method."
|
|
302
319
|
)
|
|
303
320
|
|
|
321
|
+
self.requirement_table.add_column("Variable", justify="left")
|
|
322
|
+
self.requirement_table.add_column("Shape", justify="left")
|
|
323
|
+
self.requirement_table.add_column("Constraints", justify="left")
|
|
324
|
+
self.requirement_table.add_column("Dimensions", justify="right")
|
|
325
|
+
|
|
304
326
|
def _unpack_statespace_with_placeholders(
|
|
305
327
|
self,
|
|
306
328
|
) -> tuple[
|
|
@@ -961,10 +983,31 @@ class PyMCStateSpace:
|
|
|
961
983
|
list[pm.Flat]
|
|
962
984
|
A list of pm.Flat variables representing all parameters estimated by the model.
|
|
963
985
|
"""
|
|
986
|
+
|
|
987
|
+
def infer_variable_shape(name):
|
|
988
|
+
shape = self._name_to_variable[name].type.shape
|
|
989
|
+
if not any(dim is None for dim in shape):
|
|
990
|
+
return shape
|
|
991
|
+
|
|
992
|
+
dim_names = self._fit_dims.get(name, None)
|
|
993
|
+
if dim_names is None:
|
|
994
|
+
raise ValueError(
|
|
995
|
+
f"Could not infer shape for {name}, because it was not given coords during model"
|
|
996
|
+
f"fitting"
|
|
997
|
+
)
|
|
998
|
+
|
|
999
|
+
shape_from_coords = tuple([len(self._fit_coords[dim]) for dim in dim_names])
|
|
1000
|
+
return tuple(
|
|
1001
|
+
[
|
|
1002
|
+
shape[i] if shape[i] is not None else shape_from_coords[i]
|
|
1003
|
+
for i in range(len(shape))
|
|
1004
|
+
]
|
|
1005
|
+
)
|
|
1006
|
+
|
|
964
1007
|
for name in self.param_names:
|
|
965
1008
|
pm.Flat(
|
|
966
1009
|
name,
|
|
967
|
-
shape=
|
|
1010
|
+
shape=infer_variable_shape(name),
|
|
968
1011
|
dims=self._fit_dims.get(name, None),
|
|
969
1012
|
)
|
|
970
1013
|
|
pymc_extras/version.txt
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
0.2.
|
|
1
|
+
0.2.3
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: pymc-extras
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.3
|
|
4
4
|
Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
|
|
5
5
|
Home-page: http://github.com/pymc-devs/pymc-extras
|
|
6
6
|
Maintainer: PyMC Developers
|
|
7
7
|
Maintainer-email: pymc.devs@gmail.com
|
|
8
|
-
License: Apache
|
|
8
|
+
License: Apache-2.0
|
|
9
9
|
Classifier: Development Status :: 5 - Production/Stable
|
|
10
10
|
Classifier: Programming Language :: Python
|
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
|
@@ -20,8 +20,9 @@ Classifier: Operating System :: OS Independent
|
|
|
20
20
|
Requires-Python: >=3.10
|
|
21
21
|
Description-Content-Type: text/markdown
|
|
22
22
|
License-File: LICENSE
|
|
23
|
-
Requires-Dist: pymc>=5.
|
|
23
|
+
Requires-Dist: pymc>=5.20
|
|
24
24
|
Requires-Dist: scikit-learn
|
|
25
|
+
Requires-Dist: better-optimize
|
|
25
26
|
Provides-Extra: dask-histogram
|
|
26
27
|
Requires-Dist: dask[complete]; extra == "dask-histogram"
|
|
27
28
|
Requires-Dist: xhistogram; extra == "dask-histogram"
|
|
@@ -34,6 +35,17 @@ Provides-Extra: dev
|
|
|
34
35
|
Requires-Dist: dask[all]; extra == "dev"
|
|
35
36
|
Requires-Dist: blackjax; extra == "dev"
|
|
36
37
|
Requires-Dist: statsmodels; extra == "dev"
|
|
38
|
+
Dynamic: classifier
|
|
39
|
+
Dynamic: description
|
|
40
|
+
Dynamic: description-content-type
|
|
41
|
+
Dynamic: home-page
|
|
42
|
+
Dynamic: license
|
|
43
|
+
Dynamic: maintainer
|
|
44
|
+
Dynamic: maintainer-email
|
|
45
|
+
Dynamic: provides-extra
|
|
46
|
+
Dynamic: requires-dist
|
|
47
|
+
Dynamic: requires-python
|
|
48
|
+
Dynamic: summary
|
|
37
49
|
|
|
38
50
|
# Welcome to `pymc-extras`
|
|
39
51
|
<a href="https://gitpod.io/#https://github.com/pymc-devs/pymc-extras">
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
pymc_extras/__init__.py,sha256=
|
|
1
|
+
pymc_extras/__init__.py,sha256=IFIEZdPX_Ugq57Bu7jlyrJLpKng-P0FBAAAzl2pFXLE,1266
|
|
2
2
|
pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
|
|
3
3
|
pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
|
|
4
4
|
pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
|
|
5
5
|
pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
|
|
6
|
-
pymc_extras/version.txt,sha256=
|
|
6
|
+
pymc_extras/version.txt,sha256=OrlMBNJJhvOvKIuhzaLAu928Wonf8JcYKAX1RXjh6nU,6
|
|
7
7
|
pymc_extras/distributions/__init__.py,sha256=gTX7tvX8NcgP7V72URV7GeqF1aAEjGVbuW8LMxhXceY,1295
|
|
8
8
|
pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
|
|
9
9
|
pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
|
|
@@ -14,27 +14,30 @@ pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNP
|
|
|
14
14
|
pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
|
|
15
15
|
pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
|
|
16
16
|
pymc_extras/inference/__init__.py,sha256=5cXpaQQnW0mJJ3x8wSxmYu63l--Xab5D_gMtjA6Q3uU,666
|
|
17
|
-
pymc_extras/inference/find_map.py,sha256=
|
|
18
|
-
pymc_extras/inference/fit.py,sha256=
|
|
19
|
-
pymc_extras/inference/laplace.py,sha256=
|
|
20
|
-
pymc_extras/inference/pathfinder.py,sha256=
|
|
17
|
+
pymc_extras/inference/find_map.py,sha256=vl5l0ei48PnX-uTuHVTr-9QpCEHc8xog-KK6sOnJ8LU,16513
|
|
18
|
+
pymc_extras/inference/fit.py,sha256=S9R48dh74s6K0MC9Iys4NAwVjP6rVRfx6SF-kPiR70E,1165
|
|
19
|
+
pymc_extras/inference/laplace.py,sha256=uOZGp8ssQuhvCHV_Y_v3icsr4rhcYgr_qlr9dS7pcSM,21761
|
|
20
|
+
pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
|
|
21
|
+
pymc_extras/inference/pathfinder/importance_sampling.py,sha256=VvmuaE3aw_Mo3tMwswfF0rqe19mnhOCpzIScaJzjA1Y,6159
|
|
22
|
+
pymc_extras/inference/pathfinder/lbfgs.py,sha256=P0UIOVtspdLzDU6alK-y91qzVAzXjYAXPuGmZ1nRqMo,5715
|
|
23
|
+
pymc_extras/inference/pathfinder/pathfinder.py,sha256=fomZ5voVcWxvhWpeIZV7IHGIJCasT1g0ivC4dC3-0GM,63694
|
|
21
24
|
pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
|
|
22
25
|
pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
|
|
23
26
|
pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
24
|
-
pymc_extras/model/model_api.py,sha256=
|
|
27
|
+
pymc_extras/model/model_api.py,sha256=UHMfQXxWBujeSiUySU0fDUC5Sd_BjT8FoVz3iBxQH_4,2400
|
|
25
28
|
pymc_extras/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
26
29
|
pymc_extras/model/marginal/distributions.py,sha256=iM1yT7_BmivgUSloQPKE2QXGPgjvLqDMY_OTBGsdAWg,15563
|
|
27
30
|
pymc_extras/model/marginal/graph_analysis.py,sha256=0hWUH_PjfpgneQ3NaT__pWHS1fh50zNbI86kH4Nub0E,15693
|
|
28
|
-
pymc_extras/model/marginal/marginal_model.py,sha256=
|
|
31
|
+
pymc_extras/model/marginal/marginal_model.py,sha256=oIdikaSnefCkyMxmzAe222qGXNucxZpHYk7548fK6iA,23631
|
|
29
32
|
pymc_extras/model/transforms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
33
|
pymc_extras/model/transforms/autoreparam.py,sha256=_NltGWmNqi_X9sHCqAvWcBveLTPxVy11-wENFTcN6kk,12377
|
|
31
34
|
pymc_extras/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
32
35
|
pymc_extras/preprocessing/standard_scaler.py,sha256=Vajp33ma6OkwlU54JYtSS8urHbMJ3CRiRFxZpvFNuus,600
|
|
33
36
|
pymc_extras/statespace/__init__.py,sha256=0MtZj7yT6jcyERvITnn-nkhyY8fO6Za4_vV53CF6ND0,429
|
|
34
37
|
pymc_extras/statespace/core/__init__.py,sha256=huHEiXAm8zV2MZyZ8GBHp6q7_fnWqveM7lC6ilpb3iE,309
|
|
35
|
-
pymc_extras/statespace/core/compile.py,sha256=
|
|
38
|
+
pymc_extras/statespace/core/compile.py,sha256=9FZfE8Bi3VfElxujfOIKRVvmyL9M5R0WfNEqPc5kbVQ,1603
|
|
36
39
|
pymc_extras/statespace/core/representation.py,sha256=DwNIun6wdeEA20oWBx5M4govyWTf5JI87aGQ_E6Mb4U,18956
|
|
37
|
-
pymc_extras/statespace/core/statespace.py,sha256=
|
|
40
|
+
pymc_extras/statespace/core/statespace.py,sha256=K_WVnWKlI6sR2kgriq9sctQVvwXCeAirm14TthDpmRM,96860
|
|
38
41
|
pymc_extras/statespace/filters/__init__.py,sha256=N9Q4D0gAq_ZtT-GtrqiX1HkSg6Orv7o1TbrWUtnbTJE,420
|
|
39
42
|
pymc_extras/statespace/filters/distributions.py,sha256=-9j__vRqL5hKyYFnQr5HKHA5kEFzwiuSccH4mslTOuQ,12900
|
|
40
43
|
pymc_extras/statespace/filters/kalman_filter.py,sha256=HELC3aK4k8EdWlUAk5_F7y7YkIz-Xi_0j2AwRgAXgcc,31949
|
|
@@ -58,12 +61,12 @@ pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,68
|
|
|
58
61
|
pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
|
|
59
62
|
tests/__init__.py,sha256=-ree9OWVCyTeXLR944OWjrQX2os15HXrRNkhJ7QdRjc,603
|
|
60
63
|
tests/test_blackjax_smc.py,sha256=jcNgcMBxaKyPg9UvHnWQtwoL79LXlSpZfALe3RGEZnQ,7233
|
|
61
|
-
tests/test_find_map.py,sha256=
|
|
64
|
+
tests/test_find_map.py,sha256=B8ThnXNyfTQeem24QaLoTitFrsxKoq2VQINUdOwzna0,3379
|
|
62
65
|
tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6YxzqA5X4,4674
|
|
63
|
-
tests/test_laplace.py,sha256=
|
|
66
|
+
tests/test_laplace.py,sha256=u4o-0y4v1emaTMYr_rOyL_EKY_bQIz0DUXFuwuDbfNg,9314
|
|
64
67
|
tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
|
|
65
68
|
tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
|
|
66
|
-
tests/test_pathfinder.py,sha256=
|
|
69
|
+
tests/test_pathfinder.py,sha256=GnSbZJ9QuFW9UVbkWaVgMVqQZTCttOyz_rSflxhQ-EA,4955
|
|
67
70
|
tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
|
|
68
71
|
tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
|
|
69
72
|
tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
|
|
@@ -75,7 +78,7 @@ tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcl
|
|
|
75
78
|
tests/distributions/test_discrete_markov_chain.py,sha256=8RCHZXSB8IWjniuKaGGlM_iTWGmdrcOqginxmrAeEJg,9212
|
|
76
79
|
tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIXeRBKiEn6kjs,10452
|
|
77
80
|
tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
78
|
-
tests/model/test_model_api.py,sha256=
|
|
81
|
+
tests/model/test_model_api.py,sha256=FJvMTmexovRELZOUcUyk-6Vwk9qSiH7hIFoiArgl5mk,1040
|
|
79
82
|
tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
80
83
|
tests/model/marginal/test_distributions.py,sha256=p5f73g4ogxYkdZaBndZV_1ra8TCppXiRlUpaaTwEe-M,5195
|
|
81
84
|
tests/model/marginal/test_graph_analysis.py,sha256=raoj41NusMOj1zzPCrxrlQODqX6Ey8Ft_o32pNTe5qg,6712
|
|
@@ -88,15 +91,15 @@ tests/statespace/test_coord_assignment.py,sha256=2GBm46-0eI4QNh4bvp3D7az58stcA5Z
|
|
|
88
91
|
tests/statespace/test_distributions.py,sha256=WQ_ROyd-PL3cimXTyEtyVaMEVtS7Hue2Z0lN7UnGDyo,9122
|
|
89
92
|
tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_YvQ2lM6hiQ,11600
|
|
90
93
|
tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
|
|
91
|
-
tests/statespace/test_statespace.py,sha256=
|
|
94
|
+
tests/statespace/test_statespace.py,sha256=JoupFFpG8PmpB_NFV471IuTmyXhEd6_vOISwVCRrBBM,30570
|
|
92
95
|
tests/statespace/test_statespace_JAX.py,sha256=hZOc6xxYdVeATPCKmcHMLOVcuvdzGRzgQQ4RrDenwk8,5279
|
|
93
96
|
tests/statespace/test_structural.py,sha256=HD8OaGbjuH4y3xv_uG-R1xLZpPpcb4-3dbcTeb_imLY,29306
|
|
94
97
|
tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
95
98
|
tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
|
|
96
99
|
tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
|
|
97
100
|
tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
|
|
98
|
-
pymc_extras-0.2.
|
|
99
|
-
pymc_extras-0.2.
|
|
100
|
-
pymc_extras-0.2.
|
|
101
|
-
pymc_extras-0.2.
|
|
102
|
-
pymc_extras-0.2.
|
|
101
|
+
pymc_extras-0.2.3.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
|
|
102
|
+
pymc_extras-0.2.3.dist-info/METADATA,sha256=ZTiMM7hvVRF3O_liRu4Aea_EuxJc4vHfTD2CbRRQrcU,5152
|
|
103
|
+
pymc_extras-0.2.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
104
|
+
pymc_extras-0.2.3.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
|
|
105
|
+
pymc_extras-0.2.3.dist-info/RECORD,,
|
tests/model/test_model_api.py
CHANGED
|
@@ -25,5 +25,14 @@ def test_logp():
|
|
|
25
25
|
|
|
26
26
|
mw2 = model_wrapped2(coords=coords)
|
|
27
27
|
|
|
28
|
+
@pmx.as_model()
|
|
29
|
+
def model_wrapped3(mu):
|
|
30
|
+
pm.Normal("x", mu, 1.0, dims="obs")
|
|
31
|
+
|
|
32
|
+
mw3 = model_wrapped3(0.0, coords=coords)
|
|
33
|
+
mw4 = model_wrapped3(np.array([np.nan]), coords=coords)
|
|
34
|
+
|
|
28
35
|
np.testing.assert_equal(model.point_logps(), mw.point_logps())
|
|
29
36
|
np.testing.assert_equal(mw.point_logps(), mw2.point_logps())
|
|
37
|
+
assert mw3["mu"] in mw3.data_vars
|
|
38
|
+
assert "mu" not in mw4
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
1
2
|
from functools import partial
|
|
2
3
|
|
|
3
4
|
import numpy as np
|
|
@@ -349,6 +350,59 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng):
|
|
|
349
350
|
assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values))
|
|
350
351
|
|
|
351
352
|
|
|
353
|
+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
|
|
354
|
+
def test_sample_conditional_with_time_varying():
|
|
355
|
+
class TVCovariance(PyMCStateSpace):
|
|
356
|
+
def __init__(self):
|
|
357
|
+
super().__init__(k_states=1, k_endog=1, k_posdef=1)
|
|
358
|
+
|
|
359
|
+
def make_symbolic_graph(self) -> None:
|
|
360
|
+
self.ssm["transition", 0, 0] = 1.0
|
|
361
|
+
|
|
362
|
+
self.ssm["design", 0, 0] = 1.0
|
|
363
|
+
|
|
364
|
+
sigma_cov = self.make_and_register_variable("sigma_cov", (None,))
|
|
365
|
+
self.ssm["state_cov"] = sigma_cov[:, None, None] ** 2
|
|
366
|
+
|
|
367
|
+
@property
|
|
368
|
+
def param_names(self) -> list[str]:
|
|
369
|
+
return ["sigma_cov"]
|
|
370
|
+
|
|
371
|
+
@property
|
|
372
|
+
def coords(self) -> dict[str, Sequence[str]]:
|
|
373
|
+
return make_default_coords(self)
|
|
374
|
+
|
|
375
|
+
@property
|
|
376
|
+
def state_names(self) -> list[str]:
|
|
377
|
+
return ["level"]
|
|
378
|
+
|
|
379
|
+
@property
|
|
380
|
+
def observed_states(self) -> list[str]:
|
|
381
|
+
return ["level"]
|
|
382
|
+
|
|
383
|
+
@property
|
|
384
|
+
def shock_names(self) -> list[str]:
|
|
385
|
+
return ["level"]
|
|
386
|
+
|
|
387
|
+
ss_mod = TVCovariance()
|
|
388
|
+
empty_data = pd.DataFrame(
|
|
389
|
+
np.nan, index=pd.date_range("2020-01-01", periods=100, freq="D"), columns=["data"]
|
|
390
|
+
)
|
|
391
|
+
|
|
392
|
+
coords = ss_mod.coords
|
|
393
|
+
coords["time"] = empty_data.index
|
|
394
|
+
with pm.Model(coords=coords) as mod:
|
|
395
|
+
log_sigma_cov = pm.Normal("log_sigma_cov", mu=0, sigma=0.1, dims=["time"])
|
|
396
|
+
pm.Deterministic("sigma_cov", pm.math.exp(log_sigma_cov.cumsum()), dims=["time"])
|
|
397
|
+
|
|
398
|
+
ss_mod.build_statespace_graph(data=empty_data)
|
|
399
|
+
|
|
400
|
+
prior = pm.sample_prior_predictive(10)
|
|
401
|
+
|
|
402
|
+
ss_mod.sample_unconditional_prior(prior)
|
|
403
|
+
ss_mod.sample_conditional_prior(prior)
|
|
404
|
+
|
|
405
|
+
|
|
352
406
|
def _make_time_idx(mod, use_datetime_index=True):
|
|
353
407
|
if use_datetime_index:
|
|
354
408
|
mod._fit_coords["time"] = nile.index
|
tests/test_find_map.py
CHANGED
|
@@ -54,24 +54,28 @@ def test_jax_functions_from_graph(gradient_backend: GradientBackend):
|
|
|
54
54
|
|
|
55
55
|
|
|
56
56
|
@pytest.mark.parametrize(
|
|
57
|
-
"method, use_grad, use_hess",
|
|
57
|
+
"method, use_grad, use_hess, use_hessp",
|
|
58
58
|
[
|
|
59
|
-
("nelder-mead", False, False),
|
|
60
|
-
("powell", False, False),
|
|
61
|
-
("CG", True, False),
|
|
62
|
-
("BFGS", True, False),
|
|
63
|
-
("L-BFGS-B", True, False),
|
|
64
|
-
("TNC", True, False),
|
|
65
|
-
("SLSQP", True, False),
|
|
66
|
-
("dogleg", True, True),
|
|
67
|
-
("
|
|
68
|
-
("
|
|
69
|
-
("trust-
|
|
70
|
-
("trust-
|
|
59
|
+
("nelder-mead", False, False, False),
|
|
60
|
+
("powell", False, False, False),
|
|
61
|
+
("CG", True, False, False),
|
|
62
|
+
("BFGS", True, False, False),
|
|
63
|
+
("L-BFGS-B", True, False, False),
|
|
64
|
+
("TNC", True, False, False),
|
|
65
|
+
("SLSQP", True, False, False),
|
|
66
|
+
("dogleg", True, True, False),
|
|
67
|
+
("Newton-CG", True, True, False),
|
|
68
|
+
("Newton-CG", True, False, True),
|
|
69
|
+
("trust-ncg", True, True, False),
|
|
70
|
+
("trust-ncg", True, False, True),
|
|
71
|
+
("trust-exact", True, True, False),
|
|
72
|
+
("trust-krylov", True, True, False),
|
|
73
|
+
("trust-krylov", True, False, True),
|
|
74
|
+
("trust-constr", True, True, False),
|
|
71
75
|
],
|
|
72
76
|
)
|
|
73
77
|
@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
|
|
74
|
-
def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng):
|
|
78
|
+
def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng):
|
|
75
79
|
extra_kwargs = {}
|
|
76
80
|
if method == "dogleg":
|
|
77
81
|
# HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
|
|
@@ -88,6 +92,7 @@ def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend,
|
|
|
88
92
|
**extra_kwargs,
|
|
89
93
|
use_grad=use_grad,
|
|
90
94
|
use_hess=use_hess,
|
|
95
|
+
use_hessp=use_hessp,
|
|
91
96
|
progressbar=False,
|
|
92
97
|
gradient_backend=gradient_backend,
|
|
93
98
|
compile_kwargs={"mode": "JAX"},
|
tests/test_laplace.py
CHANGED
|
@@ -19,10 +19,10 @@ import pytest
|
|
|
19
19
|
|
|
20
20
|
import pymc_extras as pmx
|
|
21
21
|
|
|
22
|
-
from pymc_extras.inference.find_map import find_MAP
|
|
22
|
+
from pymc_extras.inference.find_map import GradientBackend, find_MAP
|
|
23
23
|
from pymc_extras.inference.laplace import (
|
|
24
24
|
fit_laplace,
|
|
25
|
-
|
|
25
|
+
fit_mvn_at_MAP,
|
|
26
26
|
sample_laplace_posterior,
|
|
27
27
|
)
|
|
28
28
|
|
|
@@ -37,7 +37,11 @@ def rng():
|
|
|
37
37
|
"ignore:hessian will stop negating the output in a future version of PyMC.\n"
|
|
38
38
|
+ "To suppress this warning set `negate_output=False`:FutureWarning",
|
|
39
39
|
)
|
|
40
|
-
|
|
40
|
+
@pytest.mark.parametrize(
|
|
41
|
+
"mode, gradient_backend",
|
|
42
|
+
[(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
|
|
43
|
+
)
|
|
44
|
+
def test_laplace(mode, gradient_backend: GradientBackend):
|
|
41
45
|
# Example originates from Bayesian Data Analyses, 3rd Edition
|
|
42
46
|
# By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
|
|
43
47
|
# Aki Vehtari, and Donald Rubin.
|
|
@@ -55,7 +59,13 @@ def test_laplace():
|
|
|
55
59
|
vars = [mu, logsigma]
|
|
56
60
|
|
|
57
61
|
idata = pmx.fit(
|
|
58
|
-
method="laplace",
|
|
62
|
+
method="laplace",
|
|
63
|
+
optimize_method="trust-ncg",
|
|
64
|
+
draws=draws,
|
|
65
|
+
random_seed=173300,
|
|
66
|
+
chains=1,
|
|
67
|
+
compile_kwargs={"mode": mode},
|
|
68
|
+
gradient_backend=gradient_backend,
|
|
59
69
|
)
|
|
60
70
|
|
|
61
71
|
assert idata.posterior["mu"].shape == (1, draws)
|
|
@@ -71,7 +81,11 @@ def test_laplace():
|
|
|
71
81
|
np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
|
|
72
82
|
|
|
73
83
|
|
|
74
|
-
|
|
84
|
+
@pytest.mark.parametrize(
|
|
85
|
+
"mode, gradient_backend",
|
|
86
|
+
[(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
|
|
87
|
+
)
|
|
88
|
+
def test_laplace_only_fit(mode, gradient_backend: GradientBackend):
|
|
75
89
|
# Example originates from Bayesian Data Analyses, 3rd Edition
|
|
76
90
|
# By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
|
|
77
91
|
# Aki Vehtari, and Donald Rubin.
|
|
@@ -90,8 +104,8 @@ def test_laplace_only_fit():
|
|
|
90
104
|
method="laplace",
|
|
91
105
|
optimize_method="BFGS",
|
|
92
106
|
progressbar=True,
|
|
93
|
-
gradient_backend=
|
|
94
|
-
compile_kwargs={"mode":
|
|
107
|
+
gradient_backend=gradient_backend,
|
|
108
|
+
compile_kwargs={"mode": mode},
|
|
95
109
|
optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100),
|
|
96
110
|
random_seed=173300,
|
|
97
111
|
)
|
|
@@ -111,8 +125,11 @@ def test_laplace_only_fit():
|
|
|
111
125
|
[True, False],
|
|
112
126
|
ids=["transformed", "untransformed"],
|
|
113
127
|
)
|
|
114
|
-
@pytest.mark.parametrize(
|
|
115
|
-
|
|
128
|
+
@pytest.mark.parametrize(
|
|
129
|
+
"mode, gradient_backend",
|
|
130
|
+
[(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
|
|
131
|
+
)
|
|
132
|
+
def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: GradientBackend):
|
|
116
133
|
coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)}
|
|
117
134
|
with pm.Model(coords=coords) as model:
|
|
118
135
|
mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"])
|
|
@@ -131,13 +148,13 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
|
|
|
131
148
|
use_hessp=True,
|
|
132
149
|
progressbar=False,
|
|
133
150
|
compile_kwargs=dict(mode=mode),
|
|
134
|
-
gradient_backend=
|
|
151
|
+
gradient_backend=gradient_backend,
|
|
135
152
|
)
|
|
136
153
|
|
|
137
154
|
for value in optimized_point.values():
|
|
138
155
|
assert value.shape == (3,)
|
|
139
156
|
|
|
140
|
-
mu, H_inv =
|
|
157
|
+
mu, H_inv = fit_mvn_at_MAP(
|
|
141
158
|
optimized_point=optimized_point,
|
|
142
159
|
model=model,
|
|
143
160
|
transform_samples=transform_samples,
|
|
@@ -163,7 +180,11 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
|
|
|
163
180
|
]
|
|
164
181
|
|
|
165
182
|
|
|
166
|
-
|
|
183
|
+
@pytest.mark.parametrize(
|
|
184
|
+
"mode, gradient_backend",
|
|
185
|
+
[(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
|
|
186
|
+
)
|
|
187
|
+
def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng):
|
|
167
188
|
coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
|
|
168
189
|
with pm.Model(coords=coords) as ragged_dim_model:
|
|
169
190
|
X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])
|
|
@@ -188,8 +209,8 @@ def test_fit_laplace_ragged_coords(rng):
|
|
|
188
209
|
progressbar=False,
|
|
189
210
|
use_grad=True,
|
|
190
211
|
use_hessp=True,
|
|
191
|
-
gradient_backend=
|
|
192
|
-
compile_kwargs={"mode":
|
|
212
|
+
gradient_backend=gradient_backend,
|
|
213
|
+
compile_kwargs={"mode": mode},
|
|
193
214
|
)
|
|
194
215
|
|
|
195
216
|
assert idata["posterior"].beta.shape[-2:] == (3, 2)
|
|
@@ -206,7 +227,11 @@ def test_fit_laplace_ragged_coords(rng):
|
|
|
206
227
|
[True, False],
|
|
207
228
|
ids=["transformed", "untransformed"],
|
|
208
229
|
)
|
|
209
|
-
|
|
230
|
+
@pytest.mark.parametrize(
|
|
231
|
+
"mode, gradient_backend",
|
|
232
|
+
[(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
|
|
233
|
+
)
|
|
234
|
+
def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: GradientBackend):
|
|
210
235
|
with pm.Model() as simp_model:
|
|
211
236
|
mu = pm.Normal("mu", mu=3, sigma=0.5)
|
|
212
237
|
sigma = pm.Exponential("sigma", 1)
|
|
@@ -223,6 +248,8 @@ def test_fit_laplace(fit_in_unconstrained_space):
|
|
|
223
248
|
use_hessp=True,
|
|
224
249
|
fit_in_unconstrained_space=fit_in_unconstrained_space,
|
|
225
250
|
optimizer_kwargs=dict(maxiter=100_000, tol=1e-100),
|
|
251
|
+
compile_kwargs={"mode": mode},
|
|
252
|
+
gradient_backend=gradient_backend,
|
|
226
253
|
)
|
|
227
254
|
|
|
228
255
|
np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1)
|