pymc-extras 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/deserialize.py +10 -4
- pymc_extras/distributions/continuous.py +1 -1
- pymc_extras/distributions/histogram_utils.py +6 -4
- pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
- pymc_extras/distributions/timeseries.py +4 -2
- pymc_extras/inference/dadvi/dadvi.py +162 -72
- pymc_extras/inference/laplace_approx/find_map.py +16 -39
- pymc_extras/inference/laplace_approx/idata.py +22 -4
- pymc_extras/inference/laplace_approx/laplace.py +23 -6
- pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
- pymc_extras/inference/pathfinder/idata.py +517 -0
- pymc_extras/inference/pathfinder/pathfinder.py +61 -7
- pymc_extras/model/marginal/graph_analysis.py +2 -2
- pymc_extras/model_builder.py +9 -4
- pymc_extras/prior.py +203 -8
- pymc_extras/statespace/core/compile.py +1 -1
- pymc_extras/statespace/filters/kalman_filter.py +12 -11
- pymc_extras/statespace/filters/kalman_smoother.py +1 -3
- pymc_extras/statespace/filters/utilities.py +2 -5
- pymc_extras/statespace/models/DFM.py +12 -27
- pymc_extras/statespace/models/ETS.py +190 -198
- pymc_extras/statespace/models/SARIMAX.py +5 -17
- pymc_extras/statespace/models/VARMAX.py +15 -67
- pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
- pymc_extras/statespace/models/structural/components/regression.py +4 -26
- pymc_extras/statespace/models/utilities.py +7 -0
- pymc_extras/utils/model_equivalence.py +2 -2
- pymc_extras/utils/prior.py +10 -14
- pymc_extras/utils/spline.py +4 -10
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.6.0.dist-info}/METADATA +3 -3
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.6.0.dist-info}/RECORD +33 -32
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.6.0.dist-info}/WHEEL +1 -1
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.6.0.dist-info}/licenses/LICENSE +0 -0
pymc_extras/deserialize.py
CHANGED
@@ -13,10 +13,7 @@ Make use of the already registered deserializers:

    from pymc_extras.deserialize import deserialize

-    prior_class_data = {
-        "dist": "Normal",
-        "kwargs": {"mu": 0, "sigma": 1}
-    }
+    prior_class_data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
    prior = deserialize(prior_class_data)
    # Prior("Normal", mu=0, sigma=1)

@@ -26,6 +23,7 @@ Register custom class deserialization:

    from pymc_extras.deserialize import register_deserialization

+
    class MyClass:
        def __init__(self, value: int):
            self.value = value
@@ -34,6 +32,7 @@ Register custom class deserialization:
            # Example of what the to_dict method might look like.
            return {"value": self.value}

+
    register_deserialization(
        is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
        deserialize=lambda data: MyClass(value=data["value"]),
@@ -80,18 +79,23 @@ class Deserializer:

        from typing import Any

+
        class MyClass:
            def __init__(self, value: int):
                self.value = value

+
        from pymc_extras.deserialize import Deserializer

+
        def is_type(data: Any) -> bool:
            return data.keys() == {"value"} and isinstance(data["value"], int)

+
        def deserialize(data: dict) -> MyClass:
            return MyClass(value=data["value"])

+
        deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)

    """
@@ -196,6 +200,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:

        from pymc_extras.deserialize import register_deserialization

+
        class MyClass:
            def __init__(self, value: int):
                self.value = value
@@ -204,6 +209,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
                # Example of what the to_dict method might look like.
                return {"value": self.value}

+
        register_deserialization(
            is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
            deserialize=lambda data: MyClass(value=data["value"]),
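These hunks only reflow the docstring examples; stitched together, the documented API reads as below. This is a sketch assembled from the docstring fragments above, not new 0.6.0 behaviour:

```python
from pymc_extras.deserialize import deserialize, register_deserialization

# Round-trip the compact Prior spec shown in the updated docstring.
prior = deserialize({"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}})
# Prior("Normal", mu=0, sigma=1)


class MyClass:
    def __init__(self, value: int):
        self.value = value

    def to_dict(self) -> dict:
        # Example of what the to_dict method might look like.
        return {"value": self.value}


register_deserialization(
    is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
    deserialize=lambda data: MyClass(value=data["value"]),
)

# Dicts matching the registered is_type check now deserialize to MyClass.
assert isinstance(deserialize({"value": 1}), MyClass)
```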
pymc_extras/distributions/histogram_utils.py
CHANGED
@@ -130,8 +130,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
    ...     m = pm.Normal("m", dims="tests")
    ...     s = pm.LogNormal("s", dims="tests")
    ...     pot = pmx.distributions.histogram_approximation(
-    ...         "pot", pm.Normal.dist(m, s),
-    ...         observed=measurements, n_quantiles=50
+    ...         "pot", pm.Normal.dist(m, s), observed=measurements, n_quantiles=50
    ...     )

    For special cases like Zero Inflation in Continuous variables there is a flag.
@@ -143,8 +142,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
    ...     m = pm.Normal("m", dims="tests")
    ...     s = pm.LogNormal("s", dims="tests")
    ...     pot = pmx.distributions.histogram_approximation(
-    ...         "pot",
-    ...
+    ...         "pot",
+    ...         pm.Normal.dist(m, s),
+    ...         observed=measurements,
+    ...         n_quantiles=50,
+    ...         zero_inflation=True,
    ...     )
    """
    try:
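The reflowed doctest, assembled into a self-contained sketch. The `measurements` array, its shape, and the model coords are assumptions added for runnability; the call itself follows the docstring above (the second hunk additionally documents the `zero_inflation=True` flag for zero-inflated continuous data):

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

rng = np.random.default_rng(0)
# Assumed layout: rows are observations, columns are "tests".
measurements = rng.normal(size=(10_000, 5))

with pm.Model(coords={"tests": range(5)}):
    m = pm.Normal("m", dims="tests")
    s = pm.LogNormal("s", dims="tests")
    # Replaces the exact observed likelihood with a quantile-histogram potential.
    pot = pmx.distributions.histogram_approximation(
        "pot", pm.Normal.dist(m, s), observed=measurements, n_quantiles=50
    )
```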
pymc_extras/distributions/multivariate/r2d2m2cp.py
CHANGED
@@ -305,6 +305,7 @@ def R2D2M2CP(
        import pymc_extras as pmx
        import pymc as pm
        import numpy as np
+
        X = np.random.randn(10, 3)
        b = np.random.randn(3)
        y = X @ b + np.random.randn(10) * 0.04 + 5
@@ -339,7 +340,7 @@ def R2D2M2CP(
                # "c" - a must have in the relation
                variables_importance=[10, 1, 34],
                # NOTE: try both
-                centered=True
+                centered=True,
            )
            # intercept prior centering should be around prior predictive mean
            intercept = y.mean()
@@ -365,7 +366,7 @@ def R2D2M2CP(
                r2_std=0.2,
                # NOTE: if you know where a variable should go
                # if you do not know, leave as 0.5
-                centered=False
+                centered=False,
            )
            # intercept prior centering should be around prior predictive mean
            intercept = y.mean()
@@ -394,7 +395,7 @@ def R2D2M2CP(
                # if you do not know, leave as 0.5
                positive_probs=[0.8, 0.5, 0.1],
                # NOTE: try both
-                centered=True
+                centered=True,
            )
            intercept = y.mean()
            obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
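These R2D2M2CP hunks are purely cosmetic (trailing commas plus a blank line). For context, the surrounding docstring example looks roughly like the sketch below; the positional scale arguments, the `dims` coordinate, and the `r2` value are reconstructed from the rest of that docstring rather than shown in these hunks, so treat them as assumptions:

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

X = np.random.randn(10, 3)
b = np.random.randn(3)
y = X @ b + np.random.randn(10) * 0.04 + 5

with pm.Model(coords={"variables": ["a", "b", "c"]}):
    # R2D2M2CP returns the residual scale and the regression coefficients.
    eps, beta = pmx.distributions.R2D2M2CP(
        "beta",
        y.std(),        # assumed output scale
        X.std(0),       # assumed per-input scales
        dims="variables",
        r2=0.8,         # assumed prior belief about explained variance
        # "c" - a must have in the relation
        variables_importance=[10, 1, 34],
        # NOTE: try both
        centered=True,
    )
    # intercept prior centering should be around prior predictive mean
    intercept = y.mean()
    obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
```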
pymc_extras/distributions/timeseries.py
CHANGED
@@ -113,8 +113,10 @@ class DiscreteMarkovChain(Distribution):

        with pm.Model() as markov_chain:
            P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
-            init_dist = pm.Categorical.dist(p
-            markov_chain = pmx.DiscreteMarkovChain(
+            init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+            markov_chain = pmx.DiscreteMarkovChain(
+                "markov_chain", P=P, init_dist=init_dist, shape=(100,)
+            )

    """
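The corrected docstring example is runnable once the imports are added. A self-contained version (the chain variable is renamed here only to avoid shadowing the model name):

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

with pm.Model() as markov_chain:
    # Rows of P are the per-state transition probabilities.
    P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
    # Uniform initial distribution over the three states.
    init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
    chain = pmx.DiscreteMarkovChain(
        "markov_chain", P=P, init_dist=init_dist, shape=(100,)
    )
```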
pymc_extras/inference/dadvi/dadvi.py
CHANGED
@@ -5,7 +5,7 @@ import pytensor
import pytensor.tensor as pt
import xarray

-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
from better_optimize.constants import minimize_method
from pymc import DictToArrayBijection, Model, join_nonshared_inputs
from pymc.backends.arviz import (
@@ -13,33 +13,40 @@ from pymc.backends.arviz import (
    apply_function_over_dataset,
    coords_and_dims_for_inferencedata,
)
+from pymc.blocking import RaveledVars
from pymc.util import RandomSeed, get_default_varnames
from pytensor.tensor.variable import TensorVariable

+from pymc_extras.inference.laplace_approx.idata import (
+    add_data_to_inference_data,
+    add_optimizer_result_to_inference_data,
+)
from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
from pymc_extras.inference.laplace_approx.scipy_interface import (
-
+    scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
)


def fit_dadvi(
    model: Model | None = None,
    n_fixed_draws: int = 30,
-    random_seed: RandomSeed = None,
    n_draws: int = 1000,
-
+    include_transformed: bool = False,
    optimizer_method: minimize_method = "trust-ncg",
-    use_grad: bool =
-    use_hessp: bool =
-    use_hess: bool =
-
+    use_grad: bool | None = None,
+    use_hessp: bool | None = None,
+    use_hess: bool | None = None,
+    gradient_backend: str = "pytensor",
+    compile_kwargs: dict | None = None,
+    random_seed: RandomSeed = None,
+    progressbar: bool = True,
+    **optimizer_kwargs,
) -> az.InferenceData:
    """
-    Does inference using
-    variational inference), DADVI for short.
+    Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.

-    For full details see the paper cited in the references:
-    https://www.jmlr.org/papers/v25/23-1015.html
+    For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html

    Parameters
    ----------
@@ -47,46 +54,48 @@ def fit_dadvi(
        The PyMC model to be fit. If None, the current model context is used.

    n_fixed_draws : int
-        The number of fixed draws to use for the optimisation. More
-
-        increase inference time. Usually, the default of 30 is a good
-        tradeoff.between speed and accuracy.
+        The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
+        also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.

    random_seed: int
-        The random seed to use for the fixed draws. Running the optimisation
-
+        The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
+        the same result.

    n_draws: int
        The number of draws to return from the variational approximation.

-
-        Whether or not to keep the unconstrained variables (such as
-
+    include_transformed: bool
+        Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
+        output.

    optimizer_method: str
-        Which optimization method to use. The function calls
-
-        be
-
-        as L-BFGS-B might be faster but potentially more brittle and
-        may not converge exactly to the optimum.
-
-    minimize_kwargs:
-        Additional keyword arguments to pass to the
-        ``scipy.optimize.minimize`` function. See the documentation of
-        that function for details.
+        Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
+        can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
+        Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
+        the optimum.

-
-
-        `scipy.optimize.minimize` (where it is referred to as `jac`).
+    gradient_backend: str
+        Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".

-
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to `pytensor.function`
+
+    use_grad: bool, optional
+        If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
+
+    use_hessp: bool, optional
        If True, pass the hessian vector product to `scipy.optimize.minimize`.

-    use_hess:
-        If True, pass the hessian to `scipy.optimize.minimize`. Note that
-
-
+    use_hess: bool, optional
+        If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
+        computation can be slow and memory-intensive if there are many parameters.
+
+    progressbar: bool
+        Whether or not to show a progress bar during optimization. Default is True.
+
+    optimizer_kwargs:
+        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
+        that function for details.

    Returns
    -------
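A hedged sketch of a call using the new signature and the parameters documented above. The toy model and the import path via the module file are assumptions (the file lives at pymc_extras/inference/dadvi/dadvi.py per the file list; it may also be reachable through a higher-level entry point):

```python
import numpy as np
import pymc as pm

from pymc_extras.inference.dadvi.dadvi import fit_dadvi

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=np.array([0.1, -0.3, 0.2]))

# use_grad/use_hess/use_hessp are left as None so the method-specific
# defaults are resolved by set_optimizer_function_defaults.
idata = fit_dadvi(
    model,
    n_fixed_draws=30,
    n_draws=1000,
    optimizer_method="trust-ncg",
    gradient_backend="pytensor",
    include_transformed=False,
    random_seed=0,
    progressbar=True,
)
```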
@@ -95,16 +104,25 @@ def fit_dadvi(

    References
    ----------
-    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
-
-    Accurate, and Even More Black Box. Journal of Machine Learning
-    Research, 25(18), 1–39.
+    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
+    Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
    """

    model = pymc.modelcontext(model) if model is None else model
+    do_basinhopping = optimizer_method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = optimizer_method

    initial_point_dict = model.initial_point()
-
+    initial_point = DictToArrayBijection.map(initial_point_dict)
+    n_params = initial_point.data.shape[0]

    var_params, objective = create_dadvi_graph(
        model,
@@ -113,31 +131,65 @@ def fit_dadvi(
        n_params=n_params,
    )

-
-
-        [var_params],
-        compute_grad=use_grad,
-        compute_hessp=use_hessp,
-        compute_hess=use_hess,
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        optimizer_method, use_grad, use_hess, use_hessp
    )

-
-
-
-
-
-
-
-
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+        loss=objective,
+        inputs=[var_params],
+        initial_point_dict=None,
+        use_grad=use_grad,
+        use_hessp=use_hessp,
+        use_hess=use_hess,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+        inputs_are_flat=True,
+    )

-
-
-
-
-
-
+    dadvi_initial_point = {
+        f"{var_name}_mu": np.zeros_like(value).ravel()
+        for var_name, value in initial_point_dict.items()
+    }
+    dadvi_initial_point.update(
+        {
+            f"{var_name}_sigma__log": np.zeros_like(value).ravel()
+            for var_name, value in initial_point_dict.items()
+        }
    )

+    dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
+    args = optimizer_kwargs.pop("args", ())
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = optimizer_method
+
+        result = basinhopping(
+            func=f_fused,
+            x0=dadvi_initial_point.data,
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        result = minimize(
+            f=f_fused,
+            x0=dadvi_initial_point.data,
+            args=args,
+            method=optimizer_method,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            **optimizer_kwargs,
+        )
+
+    raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
+
    opt_var_params = result.x
    opt_means, opt_log_sds = np.split(opt_var_params, 2)

@@ -148,9 +200,29 @@ def fit_dadvi(
    draws = opt_means + draws_raw * np.exp(opt_log_sds)
    draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)

-
+    idata = dadvi_result_to_idata(
+        draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+    )

-
+    var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
+    var_name_to_model_var.update(
+        {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
+    )
+
+    idata = add_optimizer_result_to_inference_data(
+        idata=idata,
+        result=result,
+        method=optimizer_method,
+        mu=raveled_optimized,
+        model=model,
+        var_name_to_model_var=var_name_to_model_var,
+    )
+
+    idata = add_data_to_inference_data(
+        idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+    )
+
+    return idata


def create_dadvi_graph(
@@ -213,10 +285,11 @@ def create_dadvi_graph(
    return var_params, objective


-def
+def dadvi_result_to_idata(
    unstacked_draws: xarray.Dataset,
    model: Model,
-
+    include_transformed: bool = False,
+    progressbar: bool = True,
):
    """
    Transforms the unconstrained draws back into the constrained space.
@@ -232,9 +305,12 @@ def transform_draws(
    n_draws: int
        The number of draws to return from the variational approximation.

-
+    include_transformed: bool
        Whether or not to keep the unconstrained variables in the output.

+    progressbar: bool
+        Whether or not to show a progress bar during the transformation. Default is True.
+
    Returns
    -------
    :class:`~arviz.InferenceData`
@@ -243,7 +319,7 @@ def transform_draws(

    filtered_var_names = model.unobserved_value_vars
    vars_to_sample = list(
-        get_default_varnames(filtered_var_names, include_transformed=
+        get_default_varnames(filtered_var_names, include_transformed=include_transformed)
    )
    fn = pytensor.function(model.value_vars, vars_to_sample)
    point_func = PointFunc(fn)
@@ -256,6 +332,20 @@ def transform_draws(
        output_var_names=[x.name for x in vars_to_sample],
        coords=coords,
        dims=dims,
+        progressbar=progressbar,
    )

-
+    constrained_names = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
+    ]
+    all_varnames = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
+    ]
+    unconstrained_names = sorted(set(all_varnames) - set(constrained_names))
+
+    idata = az.InferenceData(posterior=transformed_result[constrained_names])
+
+    if unconstrained_names and include_transformed:
+        idata["unconstrained_posterior"] = transformed_result[unconstrained_names]
+
+    return idata
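With the changes above, `fit_dadvi` intercepts `optimizer_method="basinhopping"` and routes the problem to `better_optimize.basinhopping`, taking the inner local optimizer from `minimizer_kwargs` (defaulting to L-BFGS-B when none is given). A hedged sketch continuing the earlier example; `niter` is an assumption about what gets forwarded through `**optimizer_kwargs`:

```python
idata = fit_dadvi(
    model,
    optimizer_method="basinhopping",
    # Inner local optimizer; "L-BFGS-B" is used if this key is omitted.
    minimizer_kwargs={"method": "trust-ncg"},
    # Remaining keyword arguments are assumed to be forwarded to
    # better_optimize.basinhopping via **optimizer_kwargs.
    niter=5,
    random_seed=0,
)
```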
pymc_extras/inference/laplace_approx/find_map.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
import pymc as pm

from better_optimize import basinhopping, minimize
-from better_optimize.constants import
+from better_optimize.constants import minimize_method
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.initial_point import make_initial_point_fn
from pymc.model.transform.optimization import freeze_dims_and_data
@@ -24,40 +24,12 @@ from pymc_extras.inference.laplace_approx.idata import (
from pymc_extras.inference.laplace_approx.scipy_interface import (
    GradientBackend,
    scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
)

_log = logging.getLogger(__name__)


-def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
-    method_info = MINIMIZE_MODE_KWARGS[method].copy()
-
-    if use_hess and use_hessp:
-        _log.warning(
-            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
-            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
-            'Setting "use_hess" to False.'
-        )
-        use_hess = False
-
-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-
-    if use_hessp is not None and use_hess is None:
-        use_hess = not use_hessp
-
-    elif use_hess is not None and use_hessp is None:
-        use_hessp = not use_hess
-
-    elif use_hessp is None and use_hess is None:
-        use_hessp = method_info["uses_hessp"]
-        use_hess = method_info["uses_hess"]
-        if use_hessp and use_hess:
-            # If a method could use either hess or hessp, we default to using hessp
-            use_hess = False
-
-    return use_grad, use_hess, use_hessp
-
-
def get_nearest_psd(A: np.ndarray) -> np.ndarray:
    """
    Compute the nearest positive semi-definite matrix to a given matrix.
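The helper is removed here because it now lives in scipy_interface (see the new import above and the +47 -7 change to scipy_interface.py in the file list). Its resolution behaviour, as a sketch based on the removed body; the expected values are inferred from that implementation and from scipy's trust-ncg supporting both hess and hessp, not re-run:

```python
from pymc_extras.inference.laplace_approx.scipy_interface import (
    set_optimizer_function_defaults,
)

# Leaving all three flags as None lets the method's MINIMIZE_MODE_KWARGS entry
# decide; when a method can use both hess and hessp, hessp wins and hess is
# switched off (see the removed body above).
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
    "trust-ncg", use_grad=None, use_hess=None, use_hessp=None
)
# Expected, per the removed implementation: True, False, True
```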
@@ -196,6 +168,7 @@ def find_MAP(
    jitter_rvs: list[TensorVariable] | None = None,
    progressbar: bool = True,
    include_transformed: bool = True,
+    freeze_model: bool = True,
    gradient_backend: GradientBackend = "pytensor",
    compile_kwargs: dict | None = None,
    compute_hessian: bool = False,
@@ -238,6 +211,10 @@ def find_MAP(
        Whether to display a progress bar during optimization. Defaults to True.
    include_transformed: bool, optional
        Whether to include transformed variable values in the returned dictionary. Defaults to True.
+    freeze_model: bool, optional
+        If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
+        sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
+        True.
    gradient_backend: str, default "pytensor"
        Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
    compute_hessian: bool
@@ -257,11 +234,13 @@ def find_MAP(
        Results of Maximum A Posteriori (MAP) estimation, including the optimized point, inverse Hessian, transformed
        latent variables, and optimizer results.
    """
-    model = pm.modelcontext(model) if model is None else model
-    frozen_model = freeze_dims_and_data(model)
    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+    model = pm.modelcontext(model) if model is None else model
+
+    if freeze_model:
+        model = freeze_dims_and_data(model)

-    initial_params = _make_initial_point(
+    initial_params = _make_initial_point(model, initvals, random_seed, jitter_rvs)

    do_basinhopping = method == "basinhopping"
    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
@@ -279,8 +258,8 @@ def find_MAP(
    )

    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
-        loss=-
-        inputs=
+        loss=-model.logp(),
+        inputs=model.continuous_value_vars + model.discrete_value_vars,
        initial_point_dict=DictToArrayBijection.rmap(initial_params),
        use_grad=use_grad,
        use_hess=use_hess,
@@ -344,12 +323,10 @@ def find_MAP(
    }

    idata = map_results_to_inference_data(
-        map_point=optimized_point, model=
+        map_point=optimized_point, model=model, include_transformed=include_transformed
    )

-    idata = add_fit_to_inference_data(
-        idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
-    )
+    idata = add_fit_to_inference_data(idata=idata, mu=raveled_optimized, H_inv=H_inv, model=model)

    idata = add_optimizer_result_to_inference_data(
        idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model