pymc-extras 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- pymc_extras/deserialize.py +10 -4
- pymc_extras/distributions/continuous.py +1 -1
- pymc_extras/distributions/histogram_utils.py +6 -4
- pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
- pymc_extras/distributions/timeseries.py +4 -2
- pymc_extras/inference/__init__.py +8 -1
- pymc_extras/inference/dadvi/__init__.py +0 -0
- pymc_extras/inference/dadvi/dadvi.py +351 -0
- pymc_extras/inference/fit.py +5 -0
- pymc_extras/inference/laplace_approx/find_map.py +32 -47
- pymc_extras/inference/laplace_approx/idata.py +27 -6
- pymc_extras/inference/laplace_approx/laplace.py +24 -6
- pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
- pymc_extras/inference/pathfinder/idata.py +517 -0
- pymc_extras/inference/pathfinder/pathfinder.py +61 -7
- pymc_extras/model/marginal/graph_analysis.py +2 -2
- pymc_extras/model_builder.py +9 -4
- pymc_extras/prior.py +203 -8
- pymc_extras/statespace/core/compile.py +1 -1
- pymc_extras/statespace/filters/kalman_filter.py +12 -11
- pymc_extras/statespace/filters/kalman_smoother.py +1 -3
- pymc_extras/statespace/filters/utilities.py +2 -5
- pymc_extras/statespace/models/DFM.py +834 -0
- pymc_extras/statespace/models/ETS.py +190 -198
- pymc_extras/statespace/models/SARIMAX.py +9 -21
- pymc_extras/statespace/models/VARMAX.py +22 -74
- pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
- pymc_extras/statespace/models/structural/components/regression.py +4 -26
- pymc_extras/statespace/models/utilities.py +7 -0
- pymc_extras/statespace/utils/constants.py +3 -1
- pymc_extras/utils/model_equivalence.py +2 -2
- pymc_extras/utils/prior.py +10 -14
- pymc_extras/utils/spline.py +4 -10
- {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/METADATA +3 -3
- {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/RECORD +37 -33
- {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/WHEEL +1 -1
- {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/licenses/LICENSE +0 -0
pymc_extras/deserialize.py
CHANGED
```diff
@@ -13,10 +13,7 @@ Make use of the already registered deserializers:
 
     from pymc_extras.deserialize import deserialize
 
-    prior_class_data = {
-        "dist": "Normal",
-        "kwargs": {"mu": 0, "sigma": 1}
-    }
+    prior_class_data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
     prior = deserialize(prior_class_data)
     # Prior("Normal", mu=0, sigma=1)
 
@@ -26,6 +23,7 @@ Register custom class deserialization:
 
     from pymc_extras.deserialize import register_deserialization
 
+
     class MyClass:
         def __init__(self, value: int):
             self.value = value
@@ -34,6 +32,7 @@ Register custom class deserialization:
            # Example of what the to_dict method might look like.
            return {"value": self.value}
 
+
    register_deserialization(
        is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
        deserialize=lambda data: MyClass(value=data["value"]),
@@ -80,18 +79,23 @@ class Deserializer:
 
         from typing import Any
 
+
         class MyClass:
             def __init__(self, value: int):
                 self.value = value
 
+
         from pymc_extras.deserialize import Deserializer
 
+
         def is_type(data: Any) -> bool:
             return data.keys() == {"value"} and isinstance(data["value"], int)
 
+
         def deserialize(data: dict) -> MyClass:
             return MyClass(value=data["value"])
 
+
         deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)
 
     """
@@ -196,6 +200,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
 
        from pymc_extras.deserialize import register_deserialization
 
+
        class MyClass:
            def __init__(self, value: int):
                self.value = value
@@ -204,6 +209,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
               # Example of what the to_dict method might look like.
               return {"value": self.value}
 
+
        register_deserialization(
            is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
            deserialize=lambda data: MyClass(value=data["value"]),
```
pymc_extras/distributions/histogram_utils.py
CHANGED

```diff
@@ -130,8 +130,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
     ...     m = pm.Normal("m", dims="tests")
     ...     s = pm.LogNormal("s", dims="tests")
     ...     pot = pmx.distributions.histogram_approximation(
-    ...         "pot", pm.Normal.dist(m, s),
-    ...         observed=measurements, n_quantiles=50
+    ...         "pot", pm.Normal.dist(m, s), observed=measurements, n_quantiles=50
     ...     )
 
     For special cases like Zero Inflation in Continuous variables there is a flag.
@@ -143,8 +142,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
     ...     m = pm.Normal("m", dims="tests")
     ...     s = pm.LogNormal("s", dims="tests")
     ...     pot = pmx.distributions.histogram_approximation(
-    ...         "pot",
-    ...
+    ...         "pot",
+    ...         pm.Normal.dist(m, s),
+    ...         observed=measurements,
+    ...         n_quantiles=50,
+    ...         zero_inflation=True,
     ...     )
     """
     try:
```
pymc_extras/distributions/multivariate/r2d2m2cp.py
CHANGED

```diff
@@ -305,6 +305,7 @@ def R2D2M2CP(
         import pymc_extras as pmx
         import pymc as pm
         import numpy as np
+
         X = np.random.randn(10, 3)
         b = np.random.randn(3)
         y = X @ b + np.random.randn(10) * 0.04 + 5
@@ -339,7 +340,7 @@ def R2D2M2CP(
                 # "c" - a must have in the relation
                 variables_importance=[10, 1, 34],
                 # NOTE: try both
-                centered=True
+                centered=True,
             )
             # intercept prior centering should be around prior predictive mean
             intercept = y.mean()
@@ -365,7 +366,7 @@ def R2D2M2CP(
                 r2_std=0.2,
                 # NOTE: if you know where a variable should go
                 # if you do not know, leave as 0.5
-                centered=False
+                centered=False,
             )
             # intercept prior centering should be around prior predictive mean
             intercept = y.mean()
@@ -394,7 +395,7 @@ def R2D2M2CP(
                 # if you do not know, leave as 0.5
                 positive_probs=[0.8, 0.5, 0.1],
                 # NOTE: try both
-                centered=True
+                centered=True,
             )
             intercept = y.mean()
             obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
```
pymc_extras/distributions/timeseries.py
CHANGED

```diff
@@ -113,8 +113,10 @@ class DiscreteMarkovChain(Distribution):
 
         with pm.Model() as markov_chain:
             P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
-            init_dist = pm.Categorical.dist(p
-            markov_chain = pmx.DiscreteMarkovChain(
+            init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+            markov_chain = pmx.DiscreteMarkovChain(
+                "markov_chain", P=P, init_dist=init_dist, shape=(100,)
+            )
 
     """
 
```
pymc_extras/inference/__init__.py
CHANGED

```diff
@@ -12,9 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from pymc_extras.inference.dadvi.dadvi import fit_dadvi
 from pymc_extras.inference.fit import fit
 from pymc_extras.inference.laplace_approx.find_map import find_MAP
 from pymc_extras.inference.laplace_approx.laplace import fit_laplace
 from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
 
-__all__ = [
+__all__ = [
+    "find_MAP",
+    "fit",
+    "fit_laplace",
+    "fit_pathfinder",
+    "fit_dadvi",
+]
```
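With this change, the new `fit_dadvi` entry point is re-exported next to the existing fitting helpers. A minimal import sketch, using only the names listed in the `__all__` above:

```python
# The 0.6.0 inference namespace, per the __all__ shown in the diff above.
from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder, fit_dadvi
```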
pymc_extras/inference/dadvi/__init__.py
ADDED
File without changes
pymc_extras/inference/dadvi/dadvi.py
ADDED

```diff
@@ -0,0 +1,351 @@
+import arviz as az
+import numpy as np
+import pymc
+import pytensor
+import pytensor.tensor as pt
+import xarray
+
+from better_optimize import basinhopping, minimize
+from better_optimize.constants import minimize_method
+from pymc import DictToArrayBijection, Model, join_nonshared_inputs
+from pymc.backends.arviz import (
+    PointFunc,
+    apply_function_over_dataset,
+    coords_and_dims_for_inferencedata,
+)
+from pymc.blocking import RaveledVars
+from pymc.util import RandomSeed, get_default_varnames
+from pytensor.tensor.variable import TensorVariable
+
+from pymc_extras.inference.laplace_approx.idata import (
+    add_data_to_inference_data,
+    add_optimizer_result_to_inference_data,
+)
+from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
+from pymc_extras.inference.laplace_approx.scipy_interface import (
+    scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
+)
+
+
+def fit_dadvi(
+    model: Model | None = None,
+    n_fixed_draws: int = 30,
+    n_draws: int = 1000,
+    include_transformed: bool = False,
+    optimizer_method: minimize_method = "trust-ncg",
+    use_grad: bool | None = None,
+    use_hessp: bool | None = None,
+    use_hess: bool | None = None,
+    gradient_backend: str = "pytensor",
+    compile_kwargs: dict | None = None,
+    random_seed: RandomSeed = None,
+    progressbar: bool = True,
+    **optimizer_kwargs,
+) -> az.InferenceData:
+    """
+    Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.
+
+    For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html
+
+    Parameters
+    ----------
+    model : pm.Model
+        The PyMC model to be fit. If None, the current model context is used.
+
+    n_fixed_draws : int
+        The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
+        also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.
+
+    random_seed: int
+        The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
+        the same result.
+
+    n_draws: int
+        The number of draws to return from the variational approximation.
+
+    include_transformed: bool
+        Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
+        output.
+
+    optimizer_method: str
+        Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
+        can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
+        Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
+        the optimum.
+
+    gradient_backend: str
+        Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".
+
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to `pytensor.function`
+
+    use_grad: bool, optional
+        If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
+
+    use_hessp: bool, optional
+        If True, pass the hessian vector product to `scipy.optimize.minimize`.
+
+    use_hess: bool, optional
+        If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
+        computation can be slow and memory-intensive if there are many parameters.
+
+    progressbar: bool
+        Whether or not to show a progress bar during optimization. Default is True.
+
+    optimizer_kwargs:
+        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
+        that function for details.
+
+    Returns
+    -------
+    :class:`~arviz.InferenceData`
+        The inference data containing the results of the DADVI algorithm.
+
+    References
+    ----------
+    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
+    Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
+    """
+
+    model = pymc.modelcontext(model) if model is None else model
+    do_basinhopping = optimizer_method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = optimizer_method
+
+    initial_point_dict = model.initial_point()
+    initial_point = DictToArrayBijection.map(initial_point_dict)
+    n_params = initial_point.data.shape[0]
+
+    var_params, objective = create_dadvi_graph(
+        model,
+        n_fixed_draws=n_fixed_draws,
+        random_seed=random_seed,
+        n_params=n_params,
+    )
+
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        optimizer_method, use_grad, use_hess, use_hessp
+    )
+
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+        loss=objective,
+        inputs=[var_params],
+        initial_point_dict=None,
+        use_grad=use_grad,
+        use_hessp=use_hessp,
+        use_hess=use_hess,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+        inputs_are_flat=True,
+    )
+
+    dadvi_initial_point = {
+        f"{var_name}_mu": np.zeros_like(value).ravel()
+        for var_name, value in initial_point_dict.items()
+    }
+    dadvi_initial_point.update(
+        {
+            f"{var_name}_sigma__log": np.zeros_like(value).ravel()
+            for var_name, value in initial_point_dict.items()
+        }
+    )
+
+    dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
+    args = optimizer_kwargs.pop("args", ())
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = optimizer_method
+
+        result = basinhopping(
+            func=f_fused,
+            x0=dadvi_initial_point.data,
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        result = minimize(
+            f=f_fused,
+            x0=dadvi_initial_point.data,
+            args=args,
+            method=optimizer_method,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            **optimizer_kwargs,
+        )
+
+    raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
+
+    opt_var_params = result.x
+    opt_means, opt_log_sds = np.split(opt_var_params, 2)
+
+    # Make the draws:
+    generator = np.random.default_rng(seed=random_seed)
+    draws_raw = generator.standard_normal(size=(n_draws, n_params))
+
+    draws = opt_means + draws_raw * np.exp(opt_log_sds)
+    draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
+
+    idata = dadvi_result_to_idata(
+        draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+    )
+
+    var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
+    var_name_to_model_var.update(
+        {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
+    )
+
+    idata = add_optimizer_result_to_inference_data(
+        idata=idata,
+        result=result,
+        method=optimizer_method,
+        mu=raveled_optimized,
+        model=model,
+        var_name_to_model_var=var_name_to_model_var,
+    )
+
+    idata = add_data_to_inference_data(
+        idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+    )
+
+    return idata
+
+
+def create_dadvi_graph(
+    model: Model,
+    n_params: int,
+    n_fixed_draws: int = 30,
+    random_seed: RandomSeed = None,
+) -> tuple[TensorVariable, TensorVariable]:
+    """
+    Sets up the DADVI graph in pytensor and returns it.
+
+    Parameters
+    ----------
+    model : pm.Model
+        The PyMC model to be fit.
+
+    n_params: int
+        The total number of parameters in the model.
+
+    n_fixed_draws : int
+        The number of fixed draws to use.
+
+    random_seed: int
+        The random seed to use for the fixed draws.
+
+    Returns
+    -------
+    Tuple[TensorVariable, TensorVariable]
+        A tuple whose first element contains the variational parameters,
+        and whose second contains the DADVI objective.
+    """
+
+    # Make the fixed draws
+    generator = np.random.default_rng(seed=random_seed)
+    draws = generator.standard_normal(size=(n_fixed_draws, n_params))
+
+    inputs = model.continuous_value_vars + model.discrete_value_vars
+    initial_point_dict = model.initial_point()
+    logp = model.logp()
+
+    # Graph in terms of a flat input
+    [logp], flat_input = join_nonshared_inputs(
+        point=initial_point_dict, outputs=[logp], inputs=inputs
+    )
+
+    var_params = pt.vector(name="eta", shape=(2 * n_params,))
+
+    means, log_sds = pt.split(var_params, axis=0, splits_size=[n_params, n_params], n_splits=2)
+
+    draw_matrix = pt.constant(draws)
+    samples = means + pt.exp(log_sds) * draw_matrix
+
+    logp_vectorized_draws = pytensor.graph.vectorize_graph(logp, replace={flat_input: samples})
+
+    mean_log_density = pt.mean(logp_vectorized_draws)
+    entropy = pt.sum(log_sds)
+
+    objective = -mean_log_density - entropy
+
+    return var_params, objective
+
+
+def dadvi_result_to_idata(
+    unstacked_draws: xarray.Dataset,
+    model: Model,
+    include_transformed: bool = False,
+    progressbar: bool = True,
+):
+    """
+    Transforms the unconstrained draws back into the constrained space.
+
+    Parameters
+    ----------
+    unstacked_draws : xarray.Dataset
+        The draws to constrain back into the original space.
+
+    model : Model
+        The PyMC model the variables were derived from.
+
+    n_draws: int
+        The number of draws to return from the variational approximation.
+
+    include_transformed: bool
+        Whether or not to keep the unconstrained variables in the output.
+
+    progressbar: bool
+        Whether or not to show a progress bar during the transformation. Default is True.
+
+    Returns
+    -------
+    :class:`~arviz.InferenceData`
+        Draws from the original constrained parameters.
+    """
+
+    filtered_var_names = model.unobserved_value_vars
+    vars_to_sample = list(
+        get_default_varnames(filtered_var_names, include_transformed=include_transformed)
+    )
+    fn = pytensor.function(model.value_vars, vars_to_sample)
+    point_func = PointFunc(fn)
+
+    coords, dims = coords_and_dims_for_inferencedata(model)
+
+    transformed_result = apply_function_over_dataset(
+        point_func,
+        unstacked_draws,
+        output_var_names=[x.name for x in vars_to_sample],
+        coords=coords,
+        dims=dims,
+        progressbar=progressbar,
+    )
+
+    constrained_names = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
+    ]
+    all_varnames = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
+    ]
+    unconstrained_names = sorted(set(all_varnames) - set(constrained_names))
+
+    idata = az.InferenceData(posterior=transformed_result[constrained_names])
+
+    if unconstrained_names and include_transformed:
+        idata["unconstrained_posterior"] = transformed_result[unconstrained_names]
+
+    return idata
```
pymc_extras/inference/fit.py
CHANGED
pymc_extras/inference/laplace_approx/find_map.py
CHANGED

```diff
@@ -7,7 +7,7 @@ import numpy as np
 import pymc as pm
 
 from better_optimize import basinhopping, minimize
-from better_optimize.constants import
+from better_optimize.constants import minimize_method
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.initial_point import make_initial_point_fn
 from pymc.model.transform.optimization import freeze_dims_and_data
@@ -24,40 +24,12 @@ from pymc_extras.inference.laplace_approx.idata import (
 from pymc_extras.inference.laplace_approx.scipy_interface import (
     GradientBackend,
     scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
 )
 
 _log = logging.getLogger(__name__)
 
 
-def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
-    method_info = MINIMIZE_MODE_KWARGS[method].copy()
-
-    if use_hess and use_hessp:
-        _log.warning(
-            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
-            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
-            'Setting "use_hess" to False.'
-        )
-        use_hess = False
-
-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-
-    if use_hessp is not None and use_hess is None:
-        use_hess = not use_hessp
-
-    elif use_hess is not None and use_hessp is None:
-        use_hessp = not use_hess
-
-    elif use_hessp is None and use_hess is None:
-        use_hessp = method_info["uses_hessp"]
-        use_hess = method_info["uses_hess"]
-        if use_hessp and use_hess:
-            # If a method could use either hess or hessp, we default to using hessp
-            use_hess = False
-
-    return use_grad, use_hess, use_hessp
-
-
 def get_nearest_psd(A: np.ndarray) -> np.ndarray:
     """
     Compute the nearest positive semi-definite matrix to a given matrix.
@@ -196,8 +168,10 @@ def find_MAP(
     jitter_rvs: list[TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
+    freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
     compile_kwargs: dict | None = None,
+    compute_hessian: bool = False,
     **optimizer_kwargs,
 ) -> (
     dict[str, np.ndarray]
@@ -237,8 +211,16 @@ def find_MAP(
         Whether to display a progress bar during optimization. Defaults to True.
     include_transformed: bool, optional
         Whether to include transformed variable values in the returned dictionary. Defaults to True.
+    freeze_model: bool, optional
+        If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
+        sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
+        True.
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
+    compute_hessian: bool
+        If True, the inverse Hessian matrix at the optimum will be computed and included in the returned
+        InferenceData object. This is needed for the Laplace approximation, but can be computationally expensive for
+        high-dimensional problems. Defaults to False.
     compile_kwargs: dict, optional
         Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
     **optimizer_kwargs
@@ -252,11 +234,13 @@ def find_MAP(
         Results of Maximum A Posteriori (MAP) estimation, including the optimized point, inverse Hessian, transformed
         latent variables, and optimizer results.
     """
-    model = pm.modelcontext(model) if model is None else model
-    frozen_model = freeze_dims_and_data(model)
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+    model = pm.modelcontext(model) if model is None else model
+
+    if freeze_model:
+        model = freeze_dims_and_data(model)
 
-    initial_params = _make_initial_point(
+    initial_params = _make_initial_point(model, initvals, random_seed, jitter_rvs)
 
     do_basinhopping = method == "basinhopping"
     minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
@@ -274,8 +258,8 @@ def find_MAP(
     )
 
     f_fused, f_hessp = scipy_optimize_funcs_from_loss(
-        loss=-
-        inputs=
+        loss=-model.logp(),
+        inputs=model.continuous_value_vars + model.discrete_value_vars,
         initial_point_dict=DictToArrayBijection.rmap(initial_params),
         use_grad=use_grad,
         use_hess=use_hess,
@@ -316,14 +300,17 @@ def find_MAP(
         **optimizer_kwargs,
     )
 
-
-
-
-
-
-
-
-
+    if compute_hessian:
+        H_inv = _compute_inverse_hessian(
+            optimizer_result=optimizer_result,
+            optimal_point=None,
+            f_fused=f_fused,
+            f_hessp=f_hessp,
+            use_hess=use_hess,
+            method=method,
+        )
+    else:
+        H_inv = None
 
     raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
@@ -336,12 +323,10 @@ def find_MAP(
     }
 
     idata = map_results_to_inference_data(
-        map_point=optimized_point, model=
+        map_point=optimized_point, model=model, include_transformed=include_transformed
     )
 
-    idata = add_fit_to_inference_data(
-        idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
-    )
+    idata = add_fit_to_inference_data(idata=idata, mu=raveled_optimized, H_inv=H_inv, model=model)
 
     idata = add_optimizer_result_to_inference_data(
         idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model
```