pymc-extras 0.2.2__tar.gz → 0.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/PKG-INFO +4 -3
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/__init__.py +2 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/find_map.py +36 -16
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/laplace.py +17 -10
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/pathfinder/importance_sampling.py +23 -17
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/pathfinder/pathfinder.py +55 -23
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/marginal/marginal_model.py +2 -1
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/core/compile.py +1 -1
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/core/statespace.py +5 -4
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/filters/distributions.py +9 -45
- pymc_extras-0.2.4/pymc_extras/version.txt +1 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras.egg-info/PKG-INFO +4 -3
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras.egg-info/requires.txt +2 -1
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pyproject.toml +3 -0
- pymc_extras-0.2.4/requirements.txt +3 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/setup.py +1 -1
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_find_map.py +19 -14
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_laplace.py +42 -15
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_pathfinder.py +40 -10
- pymc_extras-0.2.2/pymc_extras/version.txt +0 -1
- pymc_extras-0.2.2/requirements.txt +0 -2
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/CODE_OF_CONDUCT.md +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/CONTRIBUTING.md +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/LICENSE +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/MANIFEST.in +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/README.md +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/continuous.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/discrete.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/histogram_utils.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/multivariate/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/multivariate/r2d2m2cp.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/timeseries.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/gp/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/gp/latent_approx.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/fit.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/pathfinder/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/pathfinder/lbfgs.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/smc/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/smc/sampling.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/linearmodel.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/marginal/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/marginal/distributions.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/marginal/graph_analysis.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/model_api.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/transforms/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/transforms/autoreparam.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model_builder.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/preprocessing/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/preprocessing/standard_scaler.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/printing.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/core/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/core/representation.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/filters/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/filters/kalman_filter.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/filters/kalman_smoother.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/filters/utilities.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/models/ETS.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/models/SARIMAX.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/models/VARMAX.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/models/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/models/structural.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/models/utilities.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/utils/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/utils/constants.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/utils/coord_tools.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/utils/data_tools.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/utils/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/utils/linear_cg.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/utils/model_equivalence.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/utils/pivoted_cholesky.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/utils/prior.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/utils/spline.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/version.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras.egg-info/SOURCES.txt +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras.egg-info/dependency_links.txt +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras.egg-info/top_level.txt +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/requirements-dev.txt +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/requirements-docs.txt +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/setup.cfg +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/distributions/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/distributions/test_continuous.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/distributions/test_discrete.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/distributions/test_discrete_markov_chain.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/distributions/test_multivariate.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/model/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/model/marginal/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/model/marginal/test_distributions.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/model/marginal/test_graph_analysis.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/model/marginal/test_marginal_model.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/model/test_model_api.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_ETS.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_SARIMAX.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_VARMAX.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_coord_assignment.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_distributions.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_kalman_filter.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_representation.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_statespace.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_statespace_JAX.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_structural.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/utilities/__init__.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/utilities/shared_fixtures.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/utilities/statsmodel_local_level.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/utilities/test_helpers.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_blackjax_smc.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_histogram_approximation.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_linearmodel.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_model_builder.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_pivoted_cholesky.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_printing.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_prior_from_trace.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_splines.py +0 -0
- {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/utils.py +0 -0
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/PKG-INFO
RENAMED
```diff
@@ -1,11 +1,11 @@
 Metadata-Version: 2.2
 Name: pymc-extras
-Version: 0.2.2
+Version: 0.2.4
 Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
 Home-page: http://github.com/pymc-devs/pymc-extras
 Maintainer: PyMC Developers
 Maintainer-email: pymc.devs@gmail.com
-License: Apache
+License: Apache-2.0
 Classifier: Development Status :: 5 - Production/Stable
 Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3
@@ -20,8 +20,9 @@ Classifier: Operating System :: OS Independent
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: pymc>=5.
+Requires-Dist: pymc>=5.21.1
 Requires-Dist: scikit-learn
+Requires-Dist: better-optimize
 Provides-Extra: dask-histogram
 Requires-Dist: dask[complete]; extra == "dask-histogram"
 Requires-Dist: xhistogram; extra == "dask-histogram"
```
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/__init__.py
RENAMED
```diff
@@ -15,7 +15,9 @@ import logging
 
 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
+from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.laplace import fit_laplace
 from pymc_extras.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
```
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/find_map.py
RENAMED
```diff
@@ -1,9 +1,9 @@
 import logging
 
 from collections.abc import Callable
+from importlib.util import find_spec
 from typing import Literal, cast, get_args
 
-import jax
 import numpy as np
 import pymc as pm
 import pytensor
```
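With `import jax` gone from module scope, JAX becomes an optional dependency: `find_spec` probes for it without importing it, and the actual import is deferred into the JAX-specific code paths (see the `import jax` added inside `_compile_grad_and_hess_to_jax` below). A minimal sketch of that pattern; the helper names here are illustrative, not from the package:

```python
from importlib.util import find_spec

def require_jax() -> None:
    # find_spec returns None when the package is absent, without importing it,
    # so the check is cheap and the error surfaces early with a clear message.
    if find_spec("jax") is None:
        raise ImportError("JAX must be installed to use JAX gradients")

def make_jitted_grad(f):
    require_jax()
    import jax  # deferred: only paid for when the jax backend is requested

    return jax.jit(jax.grad(f))
```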
```diff
@@ -30,13 +30,29 @@ VALID_BACKENDS = get_args(GradientBackend)
 def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()
 
-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
-
     if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
         use_hess = False
 
+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
     return use_grad, use_hess, use_hessp
 
 
```
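The new decision table is subtle (an explicit `use_hessp=True` now implies `use_hess=False`, and two `None`s fall back to the method's capabilities with a preference for `hessp`). A runnable sketch of the same rules, with the warning elided and `MINIMIZE_MODE_KWARGS` reduced to one hypothetical entry:

```python
# Sketch of the new defaulting rules; the "trust-ncg" capability entry is assumed.
MINIMIZE_MODE_KWARGS = {
    "trust-ncg": {"uses_grad": True, "uses_hess": True, "uses_hessp": True},
}

def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
    method_info = MINIMIZE_MODE_KWARGS[method].copy()

    if use_hess and use_hessp:
        use_hess = False  # scipy.optimize.minimize never uses both; hessp is cheaper

    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]

    if use_hessp is not None and use_hess is None:
        use_hess = not use_hessp          # an explicit choice implies its complement
    elif use_hess is not None and use_hessp is None:
        use_hessp = not use_hess
    elif use_hessp is None and use_hess is None:
        use_hessp = method_info["uses_hessp"]
        use_hess = method_info["uses_hess"]
        if use_hessp and use_hess:
            use_hess = False              # when either would work, prefer hessp

    return use_grad, use_hess, use_hessp

# All-None: method defaults apply, hessp preferred over hess.
assert set_optimizer_function_defaults("trust-ncg", None, None, None) == (True, False, True)
# Asking for hessp explicitly disables hess.
assert set_optimizer_function_defaults("trust-ncg", None, None, True) == (True, False, True)
```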
```diff
@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
         The nearest positive semi-definite matrix to the input matrix.
     """
     C = (A + A.T) / 2
-    eigval, eigvec = np.linalg.eig(C)
+    eigval, eigvec = np.linalg.eigh(C)
     eigval[eigval < 0] = 0
 
     return eigvec @ np.diag(eigval) @ eigvec.T
```
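`np.linalg.eigh` is the right routine here: `C` is symmetric by construction, and unlike the general `eig` (whose output can pick up spurious complex parts from rounding noise, making the `eigval < 0` comparison ill-defined), `eigh` guarantees real eigenvalues and eigenvectors and is faster. A self-contained check:

```python
import numpy as np

def get_nearest_psd(A: np.ndarray) -> np.ndarray:
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(C)  # real-valued by construction for symmetric C
    eigval[eigval < 0] = 0              # clip negative eigenvalues to project onto the PSD cone
    return eigvec @ np.diag(eigval) @ eigvec.T

A = np.array([[1.0, 2.0], [0.0, -1.0]])         # deliberately non-symmetric
B = get_nearest_psd(A)
assert np.all(np.linalg.eigvalsh(B) >= -1e-12)  # PSD up to floating-point noise
```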
```diff
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)
 
 
-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
     f_loss: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
@@ -122,6 +138,8 @@ def _compile_jax_gradients(
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
+    import jax
+
     f_hess = None
     f_hessp = None
 
```
```diff
@@ -152,7 +170,7 @@
     return f_loss_and_grad, f_hess, f_hessp
 
 
-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
     loss: TensorVariable,
     inputs: list[TensorVariable],
     compute_grad: bool,
@@ -177,7 +195,7 @@ def _compile_functions(
     compute_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
     compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.
 
     Returns
     -------
```
```diff
@@ -193,19 +211,19 @@ def _compile_functions(
     if compute_grad:
         grads = pytensor.gradient.grad(loss, inputs)
         grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
     else:
-        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]
 
     if compute_hess:
         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)
 
     if compute_hessp:
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)
 
     return [f_loss_and_grad, f_hess, f_hessp]
```
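The compile calls now go through `pm.compile`, which takes `(inputs, outputs)` like `pytensor.function` and forwards the same `**compile_kwargs` seen in the hunk above. A minimal sketch of the call shape:

```python
import pymc as pm
import pytensor.tensor as pt

x = pt.dvector("x")
loss = (x**2).sum()

# pm.compile wraps pytensor.function with PyMC's default compilation settings;
# extra keyword arguments (the **compile_kwargs above) are forwarded to it.
f_loss = pm.compile([x], loss)
print(f_loss([1.0, 2.0]))  # -> 5.0
```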
```diff
@@ -240,7 +258,7 @@ def scipy_optimize_funcs_from_loss(
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
     compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.
 
     Returns
     -------
@@ -265,6 +283,8 @@ def scipy_optimize_funcs_from_loss(
     )
 
     use_jax_gradients = (gradient_backend == "jax") and use_grad
+    if use_jax_gradients and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")
 
     mode = compile_kwargs.get("mode", None)
     if mode is None and use_jax_gradients:
@@ -285,7 +305,7 @@ def scipy_optimize_funcs_from_loss(
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients
 
-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
         loss=loss,
         inputs=[flat_input],
         compute_grad=compute_grad,
@@ -301,7 +321,7 @@ def scipy_optimize_funcs_from_loss(
 
     if use_jax_gradients:
         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)
 
     return f_loss, f_hess, f_hessp
```
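Taken together, these find_map.py hunks change how `find_MAP` resolves its gradient options and when it requires JAX. A minimal usage sketch (toy model; the keyword names follow this diff, and "L-BFGS-B" is a standard `scipy.optimize.minimize` method):

```python
import numpy as np
import pymc as pm
from pymc_extras.inference.find_map import find_MAP

y = np.random.default_rng(0).normal(loc=1.0, size=100)
with pm.Model():
    mu = pm.Normal("mu", 0, 10)
    pm.Normal("obs", mu=mu, sigma=1, observed=y)
    # use_hess/use_hessp are left unset (None), so the new defaulting rules
    # pick appropriate values for the chosen scipy method.
    result = find_MAP(method="L-BFGS-B", use_grad=True, gradient_backend="pytensor")
```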
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/laplace.py
RENAMED
```diff
@@ -16,6 +16,7 @@
 import logging
 
 from functools import reduce
+from importlib.util import find_spec
 from itertools import product
 from typing import Literal
 
@@ -231,7 +232,7 @@ def add_data_to_inferencedata(
     return idata
 
 
-def fit_mvn_to_MAP(
+def fit_mvn_at_MAP(
     optimized_point: dict[str, np.ndarray],
     model: pm.Model | None = None,
     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -276,6 +277,9 @@ def fit_mvn_to_MAP(
     inverse_hessian: np.ndarray
         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     """
+    if gradient_backend == "jax" and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")
+
     model = pm.modelcontext(model)
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
     frozen_model = freeze_dims_and_data(model)
```
```diff
@@ -344,8 +348,10 @@ def sample_laplace_posterior(
 
     Parameters
     ----------
-    mu
-    H_inv
+    mu: RaveledVars
+        The MAP estimate of the model parameters.
+    H_inv: np.ndarray
+        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     model : Model
         A PyMC model
     chains : int
@@ -384,9 +390,7 @@ def sample_laplace_posterior(
             constrained_rvs, replace={unconstrained_vector: batched_values}
         )
 
-        f_constrain = pm.compile_pymc(
-            inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
-        )
+        f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
 
         posterior_draws = f_constrain(posterior_draws)
 
     else:
```
```diff
@@ -472,15 +476,17 @@ def fit_laplace(
         and 1).
 
         .. warning::
-            This
+            This argument should be considered highly experimental. It has not been verified if this method produces
             valid draws from the posterior. **Use at your own risk**.
 
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
-        The number of
+        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior.
+        The number of samples to draw from the approximated posterior. Totals samples will be chains * draws.
     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
@@ -547,11 +553,12 @@ def fit_laplace(
         **optimizer_kwargs,
     )
 
-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
        optimized_point=optimized_point,
        model=model,
        on_bad_cov=on_bad_cov,
        transform_samples=fit_in_unconstrained_space,
+       gradient_backend=gradient_backend,
        zero_tol=zero_tol,
        diag_jitter=diag_jitter,
        compile_kwargs=compile_kwargs,
```
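A minimal usage sketch of the updated `fit_laplace` (toy model for illustration; the `chains`/`draws` semantics are those documented in the hunk above):

```python
import numpy as np
import pymc as pm
from pymc_extras import fit_laplace  # re-exported by the __init__.py hunk above

y = np.random.default_rng(0).normal(loc=1.0, size=100)
with pm.Model():
    mu = pm.Normal("mu", 0, 10)
    pm.Normal("obs", mu=mu, sigma=1, observed=y)
    # `chains` only shapes the output for ArviZ compatibility (the Laplace
    # approximation is not MCMC); total samples = chains * draws = 1000.
    idata = fit_laplace(chains=2, draws=500)
```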
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/pathfinder/importance_sampling.py
RENAMED
```diff
@@ -20,7 +20,7 @@ class ImportanceSamplingResult:
     samples: NDArray
     pareto_k: float | None = None
     warnings: list[str] = field(default_factory=list)
-    method: str = "none"
+    method: str = "psis"
 
 
 def importance_sampling(
@@ -28,7 +28,7 @@ def importance_sampling(
     logP: NDArray,
     logQ: NDArray,
     num_draws: int,
-    method: Literal["psis", "psir", "identity"
+    method: Literal["psis", "psir", "identity"] | None,
     random_seed: int | None = None,
 ) -> ImportanceSamplingResult:
     """Pareto Smoothed Importance Resampling (PSIR)
@@ -44,8 +44,15 @@ def importance_sampling(
         log probability values of proposal distribution, shape (L, M)
     num_draws : int
         number of draws to return where num_draws <= samples.shape[0]
-    method : str, optional
-
+    method : str, None, optional
+        Method to apply sampling based on log importance weights (logP - logQ).
+        Options are:
+        "psis" : Pareto Smoothed Importance Sampling (default)
+            Recommended for more stable results.
+        "psir" : Pareto Smoothed Importance Resampling
+            Less stable than PSIS.
+        "identity" : Applies log importance weights directly without resampling.
+        None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
     random_seed : int | None
 
     Returns
@@ -71,11 +78,11 @@ def importance_sampling(
     warnings = []
     num_paths, _, N = samples.shape
 
-    if method == "none":
+    if method is None:
         warnings.append(
             "Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
         )
-        return ImportanceSamplingResult(samples=samples, warnings=warnings)
+        return ImportanceSamplingResult(samples=samples, warnings=warnings, method=method)
     else:
         samples = samples.reshape(-1, N)
         logP = logP.ravel()
@@ -91,17 +98,16 @@ def importance_sampling(
         _warnings.filterwarnings(
             "ignore", category=RuntimeWarning, message="overflow encountered in exp"
         )
-
-
-
-
-
-
-
-
-
-
-        raise ValueError(f"Invalid importance sampling method: {method}")
+        match method:
+            case "psis":
+                replace = False
+                logiw, pareto_k = az.psislw(logiw)
+            case "psir":
+                replace = True
+                logiw, pareto_k = az.psislw(logiw)
+            case "identity":
+                replace = False
+                pareto_k = None
 
     # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
     # Pareto k may not be a good diagnostic for Pathfinder.
```
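The "psis" and "psir" branches both call `az.psislw`, which Pareto-smooths the log importance weights and returns the diagnostic shape parameter k. A standalone illustration with random stand-ins for the real log-densities:

```python
import numpy as np
import arviz as az

# Stand-ins for per-draw target and proposal log-densities.
rng = np.random.default_rng(0)
logP = rng.normal(size=4000)
logQ = rng.normal(size=4000)

logiw = logP - logQ                          # log importance weights
logiw_smoothed, pareto_k = az.psislw(logiw)  # smoothed log-weights + Pareto k
print(float(pareto_k))                       # k > 0.7 flags unreliable weights
```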
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/pathfinder/pathfinder.py
RENAMED
```diff
@@ -60,6 +60,7 @@ from pytensor.graph import Apply, Op, vectorize_graph
 from pytensor.tensor import TensorConstant, TensorVariable
 from rich.console import Console, Group
 from rich.padding import Padding
+from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
 from rich.table import Table
 from rich.text import Text
 
@@ -155,7 +156,7 @@ def convert_flat_trace_to_idata(
     postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
     inference_backend: Literal["pymc", "blackjax"] = "pymc",
     model: Model | None = None,
-    importance_sampling: Literal["psis", "psir", "identity"
+    importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
 ) -> az.InferenceData:
     """convert flattened samples to arviz InferenceData format.
 
@@ -180,7 +181,7 @@ def convert_flat_trace_to_idata(
     arviz inference data object
     """
 
-    if importance_sampling == "none":
+    if importance_sampling is None:
         # samples.ndim == 3 in this case, otherwise ndim == 2
         num_paths, num_pdraws, N = samples.shape
         samples = samples.reshape(-1, N)
@@ -219,7 +220,7 @@ def convert_flat_trace_to_idata(
     fn.trust_input = True
     result = fn(*list(trace.values()))
 
-    if importance_sampling == "none":
+    if importance_sampling is None:
         result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]
 
     elif inference_backend == "blackjax":
```
```diff
@@ -1188,7 +1189,7 @@ class MultiPathfinderResult:
     elbo_argmax: NDArray | None = None
     lbfgs_status: Counter = field(default_factory=Counter)
     path_status: Counter = field(default_factory=Counter)
-    importance_sampling: str = "
+    importance_sampling: str | None = "psis"
     warnings: list[str] = field(default_factory=list)
     pareto_k: float | None = None
 
@@ -1257,7 +1258,7 @@ class MultiPathfinderResult:
     def with_importance_sampling(
         self,
         num_draws: int,
-        method: Literal["psis", "psir", "identity"
+        method: Literal["psis", "psir", "identity"] | None,
         random_seed: int | None = None,
     ) -> Self:
         """perform importance sampling"""
@@ -1395,7 +1396,7 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
 
     path_status_message = {
-        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter is may be too close to the mean posterior and a poor exploration of the parameter space. Consider increasing jitter if this occurence is high relative to the number of paths.",
+        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
         PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
     }
```
```diff
@@ -1423,7 +1424,7 @@ def multipath_pathfinder(
     num_elbo_draws: int,
     jitter: float,
     epsilon: float,
-    importance_sampling: Literal["psis", "psir", "identity"
+    importance_sampling: Literal["psis", "psir", "identity"] | None,
     progressbar: bool,
     concurrent: Literal["thread", "process"] | None,
     random_seed: RandomSeed,
@@ -1459,8 +1460,14 @@ def multipath_pathfinder(
         Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
     epsilon: float
         value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
-    importance_sampling : str, optional
-
+    importance_sampling : str, None, optional
+        Method to apply sampling based on log importance weights (logP - logQ).
+        "psis" : Pareto Smoothed Importance Sampling (default)
+            Recommended for more stable results.
+        "psir" : Pareto Smoothed Importance Resampling
+            Less stable than PSIS.
+        "identity" : Applies log importance weights directly without resampling.
+        None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
     progressbar : bool, optional
         Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time.
     random_seed : RandomSeed, optional
@@ -1482,12 +1489,6 @@ def multipath_pathfinder(
         The result containing samples and other information from the Multi-Path Pathfinder algorithm.
     """
 
-    valid_importance_sampling = ["psis", "psir", "identity", "none", None]
-    if importance_sampling is None:
-        importance_sampling = "none"
-    if importance_sampling.lower() not in valid_importance_sampling:
-        raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
-
     *path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1)
 
     pathfinder_config = PathfinderConfig(
```
```diff
@@ -1521,12 +1522,20 @@ def multipath_pathfinder(
     results = []
     compute_start = time.time()
     try:
-
+        desc = f"Paths Complete: {{path_idx}}/{num_paths}"
+        progress = CustomProgress(
+            "[progress.description]{task.description}",
+            BarColumn(),
+            "[progress.percentage]{task.percentage:>3.0f}%",
+            TimeRemainingColumn(),
+            TextColumn("/"),
+            TimeElapsedColumn(),
             console=Console(theme=default_progress_theme),
             disable=not progressbar,
-        )
-
-
+        )
+        with progress:
+            task = progress.add_task(desc.format(path_idx=0), completed=0, total=num_paths)
+            for path_idx, result in enumerate(generator, start=1):
                 try:
                     if isinstance(result, Exception):
                         raise result
@@ -1552,7 +1561,14 @@ def multipath_pathfinder(
                                 lbfgs_status=LBFGSStatus.LBFGS_FAILED,
                             )
                         )
-
+                finally:
+                    # TODO: display LBFGS and Path Status in real time
+                    progress.update(
+                        task,
+                        description=desc.format(path_idx=path_idx),
+                        completed=path_idx,
+                        refresh=True,
+                    )
     except (KeyboardInterrupt, StopIteration) as e:
         # if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
         if isinstance(e, StopIteration):
```
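The column layout added above is plain `rich.progress` machinery. A self-contained sketch of the same layout, using `rich.progress.Progress` as a stand-in for PyMC's internal `CustomProgress` (an assumption: `CustomProgress` is taken here to accept the same column arguments):

```python
import time
from rich.progress import (BarColumn, Progress, TextColumn,
                           TimeElapsedColumn, TimeRemainingColumn)

num_paths = 4
desc = f"Paths Complete: {{path_idx}}/{num_paths}"
with Progress(
    "[progress.description]{task.description}",
    BarColumn(),
    "[progress.percentage]{task.percentage:>3.0f}%",
    TimeRemainingColumn(),
    TextColumn("/"),
    TimeElapsedColumn(),
) as progress:
    task = progress.add_task(desc.format(path_idx=0), total=num_paths)
    for path_idx in range(1, num_paths + 1):
        time.sleep(0.1)  # stand-in for fitting one pathfinder path
        progress.update(task, description=desc.format(path_idx=path_idx),
                        completed=path_idx, refresh=True)
```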
```diff
@@ -1606,7 +1622,7 @@ def fit_pathfinder(
     num_elbo_draws: int = 10,  # K
     jitter: float = 2.0,
     epsilon: float = 1e-8,
-    importance_sampling: Literal["psis", "psir", "identity"
+    importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
     progressbar: bool = True,
     concurrent: Literal["thread", "process"] | None = None,
     random_seed: RandomSeed | None = None,
@@ -1646,8 +1662,15 @@ def fit_pathfinder(
         Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
     epsilon: float
         value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
-    importance_sampling : str, optional
-
+    importance_sampling : str, None, optional
+        Method to apply sampling based on log importance weights (logP - logQ).
+        Options are:
+        "psis" : Pareto Smoothed Importance Sampling (default)
+            Recommended for more stable results.
+        "psir" : Pareto Smoothed Importance Resampling
+            Less stable than PSIS.
+        "identity" : Applies log importance weights directly without resampling.
+        None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
     progressbar : bool, optional
         Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
     random_seed : RandomSeed, optional
@@ -1674,6 +1697,15 @@ def fit_pathfinder(
     """
 
     model = modelcontext(model)
+
+    valid_importance_sampling = {"psis", "psir", "identity", None}
+
+    if importance_sampling is not None:
+        importance_sampling = importance_sampling.lower()
+
+    if importance_sampling not in valid_importance_sampling:
+        raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
+
     N = DictToArrayBijection.map(model.initial_point()).data.shape[0]
 
     if maxcor is None:
```
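With the validation moved to the top of `fit_pathfinder`, user input is lowercased and checked before any work starts. A minimal usage sketch (toy model; keyword names follow this diff's signature and docstrings):

```python
import numpy as np
import pymc as pm
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

y = np.random.default_rng(0).normal(size=200)
with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu=mu, sigma=1, observed=y)
    # "PSIS" is lowercased to "psis" by the new validation; passing None
    # instead would skip reweighting and keep per-path draws of shape
    # (num_paths, num_draws_per_path, N).
    idata = fit_pathfinder(num_paths=4, importance_sampling="PSIS")
```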
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/marginal/marginal_model.py
RENAMED
```diff
@@ -19,7 +19,8 @@ from pymc.model.fgraph import (
     model_free_rv,
     model_from_fgraph,
 )
-from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold, toposort_replace
+from pymc.pytensorf import collect_default_updates, constant_fold, toposort_replace
+from pymc.pytensorf import compile as compile_pymc
 from pymc.util import RandomState, _get_seeds_per_chain
 from pytensor import In, Out
 from pytensor.compile import SharedVariable
```
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/core/compile.py
RENAMED
```diff
@@ -30,7 +30,7 @@ def compile_statespace(
 
     inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
 
-    _f = pm.compile_pymc(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
+    _f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
 
     def f(*, draws=1, **params):
         if isinstance(steps, pt.Variable):
```
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/core/statespace.py
RENAMED
```diff
@@ -28,7 +28,6 @@ from pymc_extras.statespace.filters import (
 )
 from pymc_extras.statespace.filters.distributions import (
     LinearGaussianStateSpace,
-    MvNormalSVD,
     SequenceMvNormal,
 )
 from pymc_extras.statespace.filters.utilities import stabilize
@@ -707,7 +706,7 @@ class PyMCStateSpace:
         with pymc_model:
             for param_name in self.param_names:
                 param = getattr(pymc_model, param_name, None)
-                if param:
+                if param is not None:
                     found_params.append(param.name)
 
         missing_params = list(set(self.param_names) - set(found_params))
@@ -746,7 +745,7 @@ class PyMCStateSpace:
         with pymc_model:
             for data_name in data_names:
                 data = getattr(pymc_model, data_name, None)
-                if data:
+                if data is not None:
                     found_data.append(data.name)
 
         missing_data = list(set(data_names) - set(found_data))
```
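The switch from `if param:` to `if param is not None:` matters because model variables are symbolic objects whose truth value is ambiguous: `bool()` on them can raise or mislead, whereas an identity check against `None` never invokes `__bool__` at all. A self-contained illustration; the class below is a hypothetical stand-in for a pytensor variable:

```python
class AmbiguousVar:
    # Hypothetical stand-in: symbolic variables cannot be coerced to bool.
    def __bool__(self):
        raise TypeError("the truth value of a symbolic variable is ambiguous")

param = AmbiguousVar()

try:
    if param:          # the old-style check calls __bool__ and can blow up
        print("found")
except TypeError as e:
    print(e)

if param is not None:  # the fixed check never calls __bool__
    print("found param")
```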
```diff
@@ -2233,7 +2232,9 @@
         if shock_trajectory is None:
             shock_trajectory = pt.zeros((n_steps, self.k_posdef))
             if Q is not None:
-                init_shock = MvNormalSVD("initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM])
+                init_shock = pm.MvNormal(
+                    "initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method="svd"
+                )
             else:
                 init_shock = pm.Deterministic(
                     "initial_shock",
```