pymc-extras 0.2.7__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/inference/__init__.py +2 -2
- pymc_extras/inference/fit.py +1 -1
- pymc_extras/inference/laplace_approx/__init__.py +0 -0
- pymc_extras/inference/laplace_approx/find_map.py +354 -0
- pymc_extras/inference/laplace_approx/idata.py +393 -0
- pymc_extras/inference/laplace_approx/laplace.py +453 -0
- pymc_extras/inference/laplace_approx/scipy_interface.py +242 -0
- pymc_extras/inference/pathfinder/pathfinder.py +3 -4
- pymc_extras/linearmodel.py +3 -1
- pymc_extras/model/marginal/graph_analysis.py +4 -0
- pymc_extras/prior.py +38 -6
- pymc_extras/statespace/core/statespace.py +78 -52
- pymc_extras/statespace/filters/kalman_smoother.py +1 -1
- pymc_extras/statespace/models/structural/__init__.py +21 -0
- pymc_extras/statespace/models/structural/components/__init__.py +0 -0
- pymc_extras/statespace/models/structural/components/autoregressive.py +188 -0
- pymc_extras/statespace/models/structural/components/cycle.py +305 -0
- pymc_extras/statespace/models/structural/components/level_trend.py +257 -0
- pymc_extras/statespace/models/structural/components/measurement_error.py +137 -0
- pymc_extras/statespace/models/structural/components/regression.py +228 -0
- pymc_extras/statespace/models/structural/components/seasonality.py +445 -0
- pymc_extras/statespace/models/structural/core.py +900 -0
- pymc_extras/statespace/models/structural/utils.py +16 -0
- pymc_extras/statespace/models/utilities.py +285 -0
- pymc_extras/statespace/utils/constants.py +4 -4
- pymc_extras/statespace/utils/data_tools.py +3 -2
- {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/METADATA +6 -6
- {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/RECORD +30 -18
- pymc_extras/inference/find_map.py +0 -496
- pymc_extras/inference/laplace.py +0 -583
- pymc_extras/statespace/models/structural.py +0 -1679
- {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/WHEEL +0 -0
- {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/licenses/LICENSE +0 -0

pymc_extras/inference/laplace_approx/laplace.py (new file)

@@ -0,0 +1,453 @@
+# Copyright 2024 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+from collections.abc import Callable
+from functools import partial
+from typing import Literal
+from typing import cast as type_cast
+
+import arviz as az
+import numpy as np
+import pymc as pm
+import pytensor
+import pytensor.tensor as pt
+import xarray as xr
+
+from better_optimize.constants import minimize_method
+from numpy.typing import ArrayLike
+from pymc.blocking import DictToArrayBijection
+from pymc.model.transform.optimization import freeze_dims_and_data
+from pymc.pytensorf import join_nonshared_inputs
+from pymc.util import get_default_varnames
+from pytensor.graph import vectorize_graph
+from pytensor.tensor import TensorVariable
+from pytensor.tensor.optimize import minimize
+from pytensor.tensor.type import Variable
+
+from pymc_extras.inference.laplace_approx.find_map import (
+    _compute_inverse_hessian,
+    _make_initial_point,
+    find_MAP,
+)
+from pymc_extras.inference.laplace_approx.scipy_interface import (
+    GradientBackend,
+    scipy_optimize_funcs_from_loss,
+)
+
+_log = logging.getLogger(__name__)
+
+
+def get_conditional_gaussian_approximation(
+    x: TensorVariable,
+    Q: TensorVariable | ArrayLike,
+    mu: TensorVariable | ArrayLike,
+    args: list[TensorVariable] | None = None,
+    model: pm.Model | None = None,
+    method: minimize_method = "BFGS",
+    use_jac: bool = True,
+    use_hess: bool = False,
+    optimizer_kwargs: dict | None = None,
+) -> Callable:
+    """
+    Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation.
+
+    That is:
+    y | x, sigma ~ N(Ax, sigma^2 W)
+    x | params ~ N(mu, Q(params)^-1)
+
+    We seek to estimate log(p(x | y, params)):
+
+    log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const
+
+    Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q).
+
+    This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode.
+
+    Thus:
+
+    1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0.
+
+    2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q).
+
+    Parameters
+    ----------
+    x: TensorVariable
+        The parameter to maximize with respect to (that is, the variable in which to find the mode). In INLA, this is the latent field x~N(mu,Q^-1).
+    Q: TensorVariable | ArrayLike
+        The precision matrix of the latent field x.
+    mu: TensorVariable | ArrayLike
+        The mean of the latent field x.
+    args: list[TensorVariable]
+        Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If set to None, assumes the model RVs are args.
+    model: Model
+        PyMC model to use.
+    method: minimize_method
+        Which minimization algorithm to use.
+    use_jac: bool
+        If true, the minimizer will compute the gradient of log(p(x | y, params)).
+    use_hess: bool
+        If true, the minimizer will compute the Hessian of log(p(x | y, params)).
+    optimizer_kwargs: dict
+        Kwargs to pass to scipy.optimize.minimize.
+
+    Returns
+    -------
+    f: Callable
+        A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. x is currently both the point at which to evaluate logp and the initial guess for the minimizer.
+    """
+    model = pm.modelcontext(model)
+
+    if args is None:
+        args = model.continuous_value_vars + model.discrete_value_vars
+
+    # f = log(p(y | x, params))
+    f_x = model.logp()
+    jac = pytensor.gradient.grad(f_x, x)
+    hess = pytensor.gradient.jacobian(jac.flatten(), x)
+
+    # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
+    log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)
+
+    # Maximize log(p(x | y, params)) wrt x to find mode x0
+    x0, _ = minimize(
+        objective=-log_x_posterior,
+        x=x,
+        method=method,
+        jac=use_jac,
+        hess=use_hess,
+        optimizer_kwargs=optimizer_kwargs,
+    )
+
+    # require f'(x0) and f''(x0) for Laplace approx
+    jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
+    hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
+
+    # Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
+    _, logdetQ = pt.nlinalg.slogdet(Q)
+    conditional_gaussian_approx = (
+        -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
+    )
+
+    # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
+    # far from the mode x0 or in a neighbourhood which results in poor convergence.
+    return pytensor.function(args, [x0, conditional_gaussian_approx])
+
+
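
As a quick orientation for the new `get_conditional_gaussian_approximation` helper above, the sketch below shows one way it could be called, following its docstring. The toy latent-field model, data, and variable names are illustrative assumptions for this note, not part of the diff.

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.laplace_approx.laplace import get_conditional_gaussian_approximation

    n = 5
    Q = np.eye(n)       # prior precision of the latent field x (illustrative)
    mu = np.zeros(n)    # prior mean of the latent field x (illustrative)
    y_obs = np.random.normal(loc=1.0, size=n)

    with pm.Model() as model:
        x = pm.MvNormal("x", mu=mu, tau=Q)
        pm.Normal("y", mu=x, sigma=1.0, observed=y_obs)

        # The returned callable maps a value of x (also used as the minimizer's
        # initial guess) to [x0, approximate log p(x | y, params)].
        f_approx = get_conditional_gaussian_approximation(
            x=model.rvs_to_values[x], Q=Q, mu=mu, model=model
        )

    x0, log_post = f_approx(np.zeros(n))
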
+def _unconstrained_vector_to_constrained_rvs(model):
+    outputs = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
+    constrained_names = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
+    ]
+    names = [x.name for x in outputs]
+
+    unconstrained_names = [name for name in names if name not in constrained_names]
+
+    new_outputs, unconstrained_vector = join_nonshared_inputs(
+        model.initial_point(),
+        inputs=model.value_vars,
+        outputs=outputs,
+    )
+
+    constrained_rvs = [x for x, name in zip(new_outputs, names) if name in constrained_names]
+    value_rvs = [x for x in new_outputs if x not in constrained_rvs]
+
+    unconstrained_vector.name = "unconstrained_vector"
+
+    # Redo the names list to ensure it is sorted to match the return order
+    names = [*constrained_names, *unconstrained_names]
+
+    return names, constrained_rvs, value_rvs, unconstrained_vector
+
+
+def model_to_laplace_approx(
+    model: pm.Model, unpacked_variable_names: list[str], chains: int = 1, draws: int = 500
+):
+    initial_point = model.initial_point()
+    raveled_vars = DictToArrayBijection.map(initial_point)
+    raveled_shape = raveled_vars.data.shape[0]
+
+    # temp_chain and temp_draw are a hack to allow sampling from the Laplace approximation. We only have one mu and cov,
+    # so we add batch dims (which correspond to chains and draws). But the names "chain" and "draw" are reserved.
+    names, constrained_rvs, value_rvs, unconstrained_vector = (
+        _unconstrained_vector_to_constrained_rvs(model)
+    )
+
+    coords = model.coords | {
+        "temp_chain": np.arange(chains),
+        "temp_draw": np.arange(draws),
+        "unpacked_variable_names": unpacked_variable_names,
+    }
+
+    with pm.Model(coords=coords, model=None) as laplace_model:
+        mu = pm.Flat("mean_vector", shape=(raveled_shape,))
+        cov = pm.Flat("covariance_matrix", shape=(raveled_shape, raveled_shape))
+        laplace_approximation = pm.MvNormal(
+            "laplace_approximation",
+            mu=mu,
+            cov=cov,
+            dims=["temp_chain", "temp_draw", "unpacked_variable_names"],
+            method="svd",
+        )
+
+        cast_to_var = partial(type_cast, Variable)
+        batched_rvs = vectorize_graph(
+            type_cast(list[Variable], constrained_rvs),
+            replace={cast_to_var(unconstrained_vector): cast_to_var(laplace_approximation)},
+        )
+
+        for name, batched_rv in zip(names, batched_rvs):
+            batch_dims = ("temp_chain", "temp_draw")
+            if batched_rv.ndim == 2:
+                dims = batch_dims
+            elif name in model.named_vars_to_dims:
+                dims = (*batch_dims, *model.named_vars_to_dims[name])
+            else:
+                dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
+                initval = initial_point.get(name, None)
+                dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
+                laplace_model.add_coords(
+                    {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
+                )
+
+            pm.Deterministic(name, batched_rv, dims=dims)
+
+    return laplace_model
+
+
+def unstack_laplace_draws(laplace_data, model, chains=2, draws=500):
+    """
+    The `model_to_laplace_approx` function returns a model with a single MvNormal distribution, draws from which are
+    in the unconstrained variable space. These might be interesting to the user, but since they come back stacked in a
+    single vector, it's not easy to work with.
+
+    This function unpacks each component of the vector into its own DataArray, with the appropriate dimensions and
+    coordinates, where possible.
+    """
+    initial_point = DictToArrayBijection.map(model.initial_point())
+
+    cursor = 0
+    unstacked_laplace_draws = {}
+    coords = model.coords | {"chain": range(chains), "draw": range(draws)}
+
+    # There are corner cases where the value_vars will not have the same dimensions as the random variable (e.g.
+    # simplex transform of a Dirichlet). In these cases, we don't try to guess what the labels should be, and just
+    # add an arviz-style default dim and label.
+    for rv, (name, shape, size, dtype) in zip(model.free_RVs, initial_point.point_map_info):
+        rv_dims = []
+        for i, dim in enumerate(
+            model.named_vars_to_dims.get(rv.name, [f"{name}_dim_{i}" for i in range(len(shape))])
+        ):
+            if coords.get(dim) and shape[i] == len(coords[dim]):
+                rv_dims.append(dim)
+            else:
+                rv_dims.append(f"{name}_dim_{i}")
+                coords[f"{name}_dim_{i}"] = np.arange(shape[i])
+
+        dims = ("chain", "draw", *rv_dims)
+
+        values = (
+            laplace_data[..., cursor : cursor + size].reshape((chains, draws, *shape)).astype(dtype)
+        )
+        unstacked_laplace_draws[name] = xr.DataArray(
+            values, dims=dims, coords={dim: list(coords[dim]) for dim in dims}
+        )
+
+        cursor += size
+
+    unstacked_laplace_draws = xr.Dataset(unstacked_laplace_draws)
+
+    return unstacked_laplace_draws
+
+
+def fit_laplace(
+    optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
+    *,
+    model: pm.Model | None = None,
+    use_grad: bool | None = None,
+    use_hessp: bool | None = None,
+    use_hess: bool | None = None,
+    initvals: dict | None = None,
+    random_seed: int | np.random.Generator | None = None,
+    jitter_rvs: list[pt.TensorVariable] | None = None,
+    progressbar: bool = True,
+    include_transformed: bool = True,
+    gradient_backend: GradientBackend = "pytensor",
+    chains: int = 2,
+    draws: int = 500,
+    optimizer_kwargs: dict | None = None,
+    compile_kwargs: dict | None = None,
+) -> az.InferenceData:
+    """
+    Create a Laplace (quadratic) approximation for a posterior distribution.
+
+    This function generates a Laplace approximation for a given posterior distribution using a specified
+    number of draws. This is useful for obtaining a parametric approximation to the posterior distribution
+    that can be used for further analysis.
+
+    Parameters
+    ----------
+    model : pm.Model
+        The PyMC model to be fit. If None, the current model context is used.
+    optimize_method : str
+        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+        See scipy.optimize.minimize documentation for details.
+    use_grad : bool | None, optional
+        Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
+        the ``method``.
+    use_hessp : bool | None, optional
+        Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on
+        the ``method``.
+    use_hess : bool | None, optional
+        Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on
+        the ``method``.
+    initvals : None | dict, optional
+        Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
+        If None, the model's default initial values are used.
+    random_seed : None | int | np.random.Generator, optional
+        Seed for the random number generator or a numpy Generator for reproducibility.
+    jitter_rvs : list of TensorVariables, optional
+        Variables whose initial values should be jittered. If None, all variables are jittered.
+    progressbar : bool, optional
+        Whether to display a progress bar during optimization. Defaults to True.
+    include_transformed: bool, default True
+        Whether to include transformed variables in the output. If True, transformed variables will be included in the
+        output InferenceData object. If False, only the original variables will be included.
+    gradient_backend: str, default "pytensor"
+        The backend to use for gradient computations. Must be one of "pytensor" or "jax".
+    chains: int, default: 2
+        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+        compatible with the ArviZ library.
+    draws: int, default: 500
+        The number of samples to draw from the approximated posterior. Total samples will be chains * draws.
+    optimizer_kwargs
+        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+        ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to pytensor.function.
+
+    Returns
+    -------
+    :class:`~arviz.InferenceData`
+        An InferenceData object containing the approximated posterior samples.
+
+    Examples
+    --------
+    >>> from pymc_extras.inference import fit_laplace
+    >>> import numpy as np
+    >>> import pymc as pm
+    >>> import arviz as az
+    >>> y = np.array([2642, 3503, 4358]*10)
+    >>> with pm.Model() as m:
+    >>>     logsigma = pm.Uniform("logsigma", 1, 100)
+    >>>     mu = pm.Uniform("mu", -10000, 10000)
+    >>>     yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+    >>>     idata = fit_laplace()
+
+    Notes
+    -----
+    This method of approximation may not be suitable for all types of posterior distributions,
+    especially those with significant skewness or multimodality.
+
+    See Also
+    --------
+    fit : Calling the inference function 'fit' like pmx.fit(method="laplace", model=m)
+        will forward the call to 'fit_laplace'.
+
+    """
+    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
+    model = pm.modelcontext(model) if model is None else model
+
+    idata = find_MAP(
+        method=optimize_method,
+        model=model,
+        use_grad=use_grad,
+        use_hessp=use_hessp,
+        use_hess=use_hess,
+        initvals=initvals,
+        random_seed=random_seed,
+        jitter_rvs=jitter_rvs,
+        progressbar=progressbar,
+        include_transformed=include_transformed,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+        **optimizer_kwargs,
+    )
+
+    unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
+
+    if "covariance_matrix" not in idata.fit:
+        # The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so
+        # we have to go back and compute the Hessian at the MAP point now.
+        frozen_model = freeze_dims_and_data(model)
+        initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
+
+        _, f_hessp = scipy_optimize_funcs_from_loss(
+            loss=-frozen_model.logp(jacobian=False),
+            inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
+            initial_point_dict=DictToArrayBijection.rmap(initial_params),
+            use_grad=False,
+            use_hess=False,
+            use_hessp=True,
+            gradient_backend=gradient_backend,
+            compile_kwargs=compile_kwargs,
+        )
+        H_inv = _compute_inverse_hessian(
+            optimizer_result=None,
+            optimal_point=idata.fit.mean_vector.values,
+            f_fused=None,
+            f_hessp=f_hessp,
+            use_hess=False,
+            method=optimize_method,
+        )
+
+        idata.fit["covariance_matrix"] = xr.DataArray(
+            H_inv,
+            dims=("rows", "columns"),
+            coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
+        )
+
+    with model_to_laplace_approx(model, unpacked_variable_names, chains, draws) as laplace_model:
+        new_posterior = (
+            pm.sample_posterior_predictive(
+                idata.fit.expand_dims(chain=[0], draw=[0]),
+                extend_inferencedata=False,
+                random_seed=random_seed,
+                var_names=[
+                    "laplace_approximation",
+                    *[x.name for x in laplace_model.deterministics],
+                ],
+            )
+            .posterior_predictive.squeeze(["chain", "draw"])
+            .drop_vars(["chain", "draw"])
+            .rename({"temp_chain": "chain", "temp_draw": "draw"})
+        )
+
+        if include_transformed:
+            idata.unconstrained_posterior = unstack_laplace_draws(
+                new_posterior.laplace_approximation.values, model, chains=chains, draws=draws
+            )
+
+        idata.posterior = new_posterior.drop_vars(
+            ["laplace_approximation", "unpacked_variable_names"]
+        )
+
+    return idata
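
Building on the docstring example above, the returned InferenceData can be inspected roughly as follows. This is a hedged sketch of the groups the new code attaches (`posterior`, `unconstrained_posterior`, and the `fit` group produced by `find_MAP`), reusing the toy model from the docstring.

    import arviz as az
    import numpy as np
    import pymc as pm

    from pymc_extras.inference import fit_laplace

    y = np.array([2642, 3503, 4358] * 10)
    with pm.Model() as m:
        logsigma = pm.Uniform("logsigma", 1, 100)
        mu = pm.Uniform("mu", -10000, 10000)
        pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
        idata = fit_laplace(chains=2, draws=500)

    idata.posterior                 # (chain, draw) samples of logsigma and mu
    idata.unconstrained_posterior   # draws in the unconstrained space (include_transformed=True)
    idata.fit["mean_vector"]        # MAP estimate, indexed by the "rows" coordinate
    idata.fit["covariance_matrix"]  # inverse Hessian at the MAP, dims ("rows", "columns")
    az.summary(idata)
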
pymc_extras/inference/laplace_approx/scipy_interface.py (new file)

@@ -0,0 +1,242 @@
+from collections.abc import Callable
+from importlib.util import find_spec
+from typing import Literal, get_args
+
+import numpy as np
+import pymc as pm
+import pytensor
+
+from pymc import join_nonshared_inputs
+from pytensor import tensor as pt
+from pytensor.compile import Function
+from pytensor.tensor import TensorVariable
+
+GradientBackend = Literal["pytensor", "jax"]
+VALID_BACKENDS = get_args(GradientBackend)
+
+
+def _compile_grad_and_hess_to_jax(
+    f_fused: Function, use_hess: bool, use_hessp: bool
+) -> tuple[Callable | None, Callable | None]:
+    """
+    Compile loss function gradients using JAX.
+
+    Parameters
+    ----------
+    f_fused: Function
+        The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss,
+        compiled with mode="JAX".
+    use_hess: bool
+        Whether to compile a function to compute the hessian of the loss function.
+    use_hessp: bool
+        Whether to compile a function to compute the hessian-vector product of the loss function.
+
+    Returns
+    -------
+    f_fused: Callable
+        The compiled loss function and gradient function, which may also compute the hessian if requested.
+    f_hessp: Callable | None
+        The compiled hessian-vector product function, or None if use_hessp is False.
+    """
+    import jax
+
+    f_hessp = None
+
+    orig_loss_fn = f_fused.vm.jit_fn
+
+    if use_hess:
+
+        @jax.jit
+        def loss_fn_fused(x):
+            loss_and_grad = jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
+            hess = jax.hessian(lambda x: orig_loss_fn(x)[0])(x)
+            return *loss_and_grad, hess
+
+    else:
+
+        @jax.jit
+        def loss_fn_fused(x):
+            return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
+
+    if use_hessp:
+
+        def f_hessp_jax(x, p):
+            y, u = jax.jvp(lambda x: loss_fn_fused(x)[1], (x,), (p,))
+            return jax.numpy.stack(u)
+
+        f_hessp = jax.jit(f_hessp_jax)
+
+    return loss_fn_fused, f_hessp
+
+
+def _compile_functions_for_scipy_optimize(
+    loss: TensorVariable,
+    inputs: list[TensorVariable],
+    compute_grad: bool,
+    compute_hess: bool,
+    compute_hessp: bool,
+    compile_kwargs: dict | None = None,
+) -> list[Function] | list[Function, Function | None, Function | None]:
+    """
+    Compile loss functions for use with scipy.optimize.minimize.
+
+    Parameters
+    ----------
+    loss: TensorVariable
+        The loss function to compile.
+    inputs: list[TensorVariable]
+        A single flat vector input variable, collecting all inputs to the loss function. Scipy optimize routines
+        expect the function signature to be f(x, *args), where x is a 1D array of parameters.
+    compute_grad: bool
+        Whether to compile a function that computes the gradients of the loss function.
+    compute_hess: bool
+        Whether to compile a function that computes the Hessian of the loss function.
+    compute_hessp: bool
+        Whether to compile a function that computes the Hessian-vector product of the loss function.
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to the ``pm.compile`` function.
+
+    Returns
+    -------
+    f_fused: Function
+        The compiled loss function, which may also include gradients and hessian if requested.
+    f_hessp: Function | None
+        The compiled hessian-vector product function, or None if compute_hessp is False.
+    """
+    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+
+    loss = pm.pytensorf.rewrite_pregrad(loss)
+    f_hessp = None
+
+    # In the simplest case, we only compile the loss function. Return it as a list to keep the return type consistent
+    # with the case where we also compute gradients, hessians, or hessian-vector products.
+    if not (compute_grad or compute_hess or compute_hessp):
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
+        return [f_loss]
+
+    # Otherwise there are three cases. If the user only wants the loss function and gradients, we compile a single
+    # fused function and return it. If the user also wants the hessian, the fused function will return the loss,
+    # gradients and hessian. If the user wants gradients and hess_p, we return a fused function that returns the loss
+    # and gradients, and a separate function for the hessian-vector product.
+
+    if compute_hessp:
+        # Handle this first, since it can be compiled alone.
+        p = pt.tensor("p", shape=inputs[0].type.shape)
+        hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)
+
+    outputs = [loss]
+
+    if compute_grad:
+        grads = pytensor.gradient.grad(loss, inputs)
+        grad = pt.concatenate([grad.ravel() for grad in grads])
+        outputs.append(grad)
+
+        if compute_hess:
+            hess = pytensor.gradient.jacobian(grad, inputs)[0]
+            outputs.append(hess)
+
+    f_fused = pm.compile(inputs, outputs, **compile_kwargs)
+
+    return [f_fused, f_hessp]
+
+
+def scipy_optimize_funcs_from_loss(
+    loss: TensorVariable,
+    inputs: list[TensorVariable],
+    initial_point_dict: dict[str, np.ndarray | float | int],
+    use_grad: bool,
+    use_hess: bool,
+    use_hessp: bool,
+    gradient_backend: GradientBackend = "pytensor",
+    compile_kwargs: dict | None = None,
+) -> tuple[Callable, ...]:
+    """
+    Compile loss functions for use with scipy.optimize.minimize.
+
+    Parameters
+    ----------
+    loss: TensorVariable
+        The loss function to compile.
+    inputs: list[TensorVariable]
+        The input variables to the loss function.
+    initial_point_dict: dict[str, np.ndarray | float | int]
+        Dictionary mapping variable names to initial values. Used to determine the shapes of the input variables.
+    use_grad: bool
+        Whether to compile a function that computes the gradients of the loss function.
+    use_hess: bool
+        Whether to compile a function that computes the Hessian of the loss function.
+    use_hessp: bool
+        Whether to compile a function that computes the Hessian-vector product of the loss function.
+    gradient_backend: str, default "pytensor"
+        Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
+    compile_kwargs:
+        Additional keyword arguments to pass to the ``pm.compile`` function.
+
+    Returns
+    -------
+    f_fused: Callable
+        The compiled loss function, which may also include gradients and hessian if requested.
+    f_hessp: Callable | None
+        The compiled hessian-vector product function, or None if use_hessp is False.
+    """
+
+    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+
+    if use_hess and not use_grad:
+        raise ValueError("Cannot compute hessian without also computing the gradient")
+
+    if gradient_backend not in VALID_BACKENDS:
+        raise ValueError(
+            f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}"
+        )
+
+    use_jax_gradients = (gradient_backend == "jax") and use_grad
+    if use_jax_gradients and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")
+
+    mode = compile_kwargs.get("mode", None)
+    if mode is None and use_jax_gradients:
+        compile_kwargs["mode"] = "JAX"
+    elif mode != "JAX" and use_jax_gradients:
+        raise ValueError(
+            'jax gradients can only be used when ``compile_kwargs["mode"]`` is set to "JAX"'
+        )
+
+    if not isinstance(inputs, list):
+        inputs = [inputs]
+
+    [loss], flat_input = join_nonshared_inputs(
+        point=initial_point_dict, outputs=[loss], inputs=inputs
+    )
+
+    # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
+    # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
+    # away.
+    if use_jax_gradients:
+        from pymc.sampling.jax import _replace_shared_variables
+
+        [loss] = _replace_shared_variables([loss])
+
+    compute_grad = use_grad and not use_jax_gradients
+    compute_hess = use_hess and not use_jax_gradients
+    compute_hessp = use_hessp and not use_jax_gradients
+
+    funcs = _compile_functions_for_scipy_optimize(
+        loss=loss,
+        inputs=[flat_input],
+        compute_grad=compute_grad,
+        compute_hess=compute_hess,
+        compute_hessp=compute_hessp,
+        compile_kwargs=compile_kwargs,
+    )
+
+    # Depending on the requested functions, f_fused will either be the loss function, the loss function with gradients,
+    # or the loss function with gradients and hessian.
+    f_fused = funcs.pop(0)
+    f_hessp = funcs.pop(0) if compute_hessp else None
+
+    if use_jax_gradients:
+        f_fused, f_hessp = _compile_grad_and_hess_to_jax(f_fused, use_hess, use_hessp)
+
+    return f_fused, f_hessp
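
Finally, a hedged sketch of how `scipy_optimize_funcs_from_loss` plugs into `scipy.optimize.minimize` directly, mirroring how the `fit_laplace` hunk above builds its Hessian-vector product. The toy model and the choice of the `trust-ncg` method are illustrative, not taken from the diff.

    import numpy as np
    import pymc as pm
    from scipy.optimize import minimize

    from pymc.blocking import DictToArrayBijection
    from pymc_extras.inference.laplace_approx.scipy_interface import scipy_optimize_funcs_from_loss

    with pm.Model() as m:
        mu = pm.Normal("mu")
        pm.Normal("y", mu=mu, sigma=1.0, observed=np.array([0.5, 1.5, 2.0]))

    start = m.initial_point()
    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
        loss=-m.logp(jacobian=False),          # negative log posterior, as in fit_laplace
        inputs=m.continuous_value_vars,
        initial_point_dict=start,
        use_grad=True,
        use_hess=False,
        use_hessp=True,
    )

    # f_fused(x) returns (loss, gradient); f_hessp(x, p) returns a Hessian-vector product.
    x0 = DictToArrayBijection.map(start).data
    res = minimize(f_fused, x0, jac=True, hessp=f_hessp, method="trust-ncg")
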