pymc-extras 0.2.7__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/inference/__init__.py +2 -2
- pymc_extras/inference/fit.py +1 -1
- pymc_extras/inference/laplace_approx/__init__.py +0 -0
- pymc_extras/inference/laplace_approx/find_map.py +354 -0
- pymc_extras/inference/laplace_approx/idata.py +393 -0
- pymc_extras/inference/laplace_approx/laplace.py +453 -0
- pymc_extras/inference/laplace_approx/scipy_interface.py +242 -0
- pymc_extras/inference/pathfinder/pathfinder.py +3 -4
- pymc_extras/linearmodel.py +3 -1
- pymc_extras/model/marginal/graph_analysis.py +4 -0
- pymc_extras/prior.py +38 -6
- pymc_extras/statespace/core/statespace.py +78 -52
- pymc_extras/statespace/filters/kalman_smoother.py +1 -1
- pymc_extras/statespace/models/structural/__init__.py +21 -0
- pymc_extras/statespace/models/structural/components/__init__.py +0 -0
- pymc_extras/statespace/models/structural/components/autoregressive.py +188 -0
- pymc_extras/statespace/models/structural/components/cycle.py +305 -0
- pymc_extras/statespace/models/structural/components/level_trend.py +257 -0
- pymc_extras/statespace/models/structural/components/measurement_error.py +137 -0
- pymc_extras/statespace/models/structural/components/regression.py +228 -0
- pymc_extras/statespace/models/structural/components/seasonality.py +445 -0
- pymc_extras/statespace/models/structural/core.py +900 -0
- pymc_extras/statespace/models/structural/utils.py +16 -0
- pymc_extras/statespace/models/utilities.py +285 -0
- pymc_extras/statespace/utils/constants.py +4 -4
- pymc_extras/statespace/utils/data_tools.py +3 -2
- {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/METADATA +6 -6
- {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/RECORD +30 -18
- pymc_extras/inference/find_map.py +0 -496
- pymc_extras/inference/laplace.py +0 -583
- pymc_extras/statespace/models/structural.py +0 -1679
- {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/WHEEL +0 -0
- {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/licenses/LICENSE +0 -0
pymc_extras/inference/laplace.py
DELETED
|
@@ -1,583 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 The PyMC Developers
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
import logging
|
|
17
|
-
|
|
18
|
-
from functools import reduce
|
|
19
|
-
from importlib.util import find_spec
|
|
20
|
-
from itertools import product
|
|
21
|
-
from typing import Literal
|
|
22
|
-
|
|
23
|
-
import arviz as az
|
|
24
|
-
import numpy as np
|
|
25
|
-
import pymc as pm
|
|
26
|
-
import pytensor
|
|
27
|
-
import pytensor.tensor as pt
|
|
28
|
-
import xarray as xr
|
|
29
|
-
|
|
30
|
-
from arviz import dict_to_dataset
|
|
31
|
-
from better_optimize.constants import minimize_method
|
|
32
|
-
from pymc import DictToArrayBijection
|
|
33
|
-
from pymc.backends.arviz import (
|
|
34
|
-
coords_and_dims_for_inferencedata,
|
|
35
|
-
find_constants,
|
|
36
|
-
find_observations,
|
|
37
|
-
)
|
|
38
|
-
from pymc.blocking import RaveledVars
|
|
39
|
-
from pymc.model.transform.conditioning import remove_value_transforms
|
|
40
|
-
from pymc.model.transform.optimization import freeze_dims_and_data
|
|
41
|
-
from pymc.util import get_default_varnames
|
|
42
|
-
from scipy import stats
|
|
43
|
-
|
|
44
|
-
from pymc_extras.inference.find_map import (
|
|
45
|
-
GradientBackend,
|
|
46
|
-
_unconstrained_vector_to_constrained_rvs,
|
|
47
|
-
find_MAP,
|
|
48
|
-
get_nearest_psd,
|
|
49
|
-
scipy_optimize_funcs_from_loss,
|
|
50
|
-
)
|
|
51
|
-
|
|
52
|
-
_log = logging.getLogger(__name__)
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def laplace_draws_to_inferencedata(
    posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None
) -> az.InferenceData:
    """
    Convert draws from a posterior estimated with the Laplace approximation to an InferenceData object.

    Parameters
    ----------
    posterior_draws: list of np.ndarray
        A list of arrays containing the posterior draws. Each array should have shape (chains, draws, *shape), where
        shape is the shape of the variable in the posterior.
    model: Model, optional
        A PyMC model. If None, the model is taken from the current model context.

    Returns
    -------
    idata: az.InferenceData
        An InferenceData object containing the approximated posterior samples
    """
    model = pm.modelcontext(model)
    # All draw arrays share the same leading (chains, draws) dimensions.
    chains, draws, *_ = posterior_draws[0].shape

    def make_rv_coords(name):
        # Base sampling coords plus any model-declared dims for this variable.
        coords = {"chain": range(chains), "draw": range(draws)}
        extra_dims = model.named_vars_to_dims.get(name)
        if extra_dims is None:
            return coords
        return coords | {dim: list(model.coords[dim]) for dim in extra_dims}

    def make_rv_dims(name):
        dims = ["chain", "draw"]
        extra_dims = model.named_vars_to_dims.get(name)
        if extra_dims is None:
            return dims
        return dims + list(extra_dims)

    names = [
        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
    ]
    # NOTE: the per-variable array is named `draw_array` (not `draws`) so it does not
    # shadow the integer `draws` count read by make_rv_coords via closure.
    idata = {
        name: xr.DataArray(
            data=draw_array,
            coords=make_rv_coords(name),
            dims=make_rv_dims(name),
            name=name,
        )
        for name, draw_array in zip(names, posterior_draws)
    }

    coords, dims = coords_and_dims_for_inferencedata(model)
    idata = az.convert_to_inference_data(idata, coords=coords, dims=dims)

    return idata
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
def add_fit_to_inferencedata(
    idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
) -> az.InferenceData:
    """
    Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.

    Parameters
    ----------
    idata: az.InferenceData
        An InferenceData object containing the approximated posterior samples.
    mu: RaveledVars
        The MAP estimate of the model parameters.
    H_inv: np.ndarray
        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
    model: Model, optional
        A PyMC model. If None, the model is taken from the current model context.

    Returns
    -------
    idata: az.InferenceData
        The provided InferenceData, with the mean vector and covariance matrix added to the "fit" group.
    """
    model = pm.modelcontext(model)
    coords = model.coords

    variable_names, *_ = zip(*mu.point_map_info)

    # Build the name -> dims lookup once, outside the per-variable helper: it only
    # depends on the model, and the original rebuilt it for every variable.
    value_to_dim = {
        x.name: model.named_vars_to_dims.get(model.values_to_rvs[x].name, None)
        for x in model.value_vars
    }
    value_to_dim = {k: v for k, v in value_to_dim.items() if v is not None}
    dims_dict = model.named_vars_to_dims | value_to_dim

    def make_unpacked_variable_names(name):
        # Expand a (possibly multidimensional) variable into one label per scalar
        # entry, e.g. "beta[city_a,2020]"; scalars keep their bare name.
        dims = dims_dict.get(name)
        if dims is None:
            return [name]
        labels = product(*(coords[dim] for dim in dims))
        return [f"{name}[{','.join(map(str, label))}]" for label in labels]

    # Flat list of scalar labels, in the same order as the raveled `mu` vector.
    unpacked_variable_names = []
    for name in variable_names:
        unpacked_variable_names.extend(make_unpacked_variable_names(name))

    mean_dataarray = xr.DataArray(mu.data, dims=["rows"], coords={"rows": unpacked_variable_names})
    cov_dataarray = xr.DataArray(
        H_inv,
        dims=["rows", "columns"],
        coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
    )

    dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray})
    idata.add_groups(fit=dataset)

    return idata
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
def add_data_to_inferencedata(
    idata: az.InferenceData,
    progressbar: bool = True,
    model: pm.Model | None = None,
    compile_kwargs: dict | None = None,
) -> az.InferenceData:
    """
    Add observed and constant data to an InferenceData object.

    Parameters
    ----------
    idata: az.InferenceData
        An InferenceData object containing the approximated posterior samples.
    progressbar: bool
        Whether to display a progress bar during computations. Default is True.
    model: Model, optional
        A PyMC model. If None, the model is taken from the current model context.
    compile_kwargs: dict, optional
        Additional keyword arguments to pass to pytensor.function.

    Returns
    -------
    idata: az.InferenceData
        The provided InferenceData, with observed and constant data added.
    """
    model = pm.modelcontext(model)

    # Deterministic quantities are not sampled directly; compute them from the
    # posterior draws and merge them into the posterior group.
    if model.deterministics:
        idata.posterior = pm.compute_deterministics(
            idata.posterior,
            model=model,
            merge_dataset=True,
            progressbar=progressbar,
            compile_kwargs=compile_kwargs,
        )

    coords, dims = coords_and_dims_for_inferencedata(model)

    def _to_dataset(var_dict):
        # Convert a name -> array mapping into an xarray Dataset using the model's
        # coords/dims; default_dims=[] keeps chain/draw out of these static groups.
        return dict_to_dataset(var_dict, library=pm, coords=coords, dims=dims, default_dims=[])

    idata.add_groups(
        {
            "observed_data": _to_dataset(find_observations(model)),
            "constant_data": _to_dataset(find_constants(model)),
        },
        coords=coords,
        dims=dims,
    )

    return idata
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
def fit_mvn_at_MAP(
    optimized_point: dict[str, np.ndarray],
    model: pm.Model | None = None,
    on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
    transform_samples: bool = False,
    gradient_backend: GradientBackend = "pytensor",
    zero_tol: float = 1e-8,
    diag_jitter: float | None = 1e-8,
    compile_kwargs: dict | None = None,
) -> tuple[RaveledVars, np.ndarray]:
    """
    Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior
    evaluated at the MAP estimate. This is the basis of the Laplace approximation.

    Parameters
    ----------
    optimized_point : dict[str, np.ndarray]
        Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map
    model : Model, optional
        A PyMC model. If None, the model is taken from the current model context.
    on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
        What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
        If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
        If 'error', an error will be raised.
    transform_samples : bool
        Whether to compute the Hessian in the transformed (unconstrained) parameter space. Default is False.
    gradient_backend: str, default "pytensor"
        The backend to use for gradient computations. Must be one of "pytensor" or "jax".
    zero_tol: float
        Value below which an element of the Hessian matrix is counted as 0.
        This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
    diag_jitter: float | None
        A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
        If None, no jitter is added. Default is 1e-8.
    compile_kwargs: dict, optional
        Additional keyword arguments to pass to pytensor.function when compiling loss functions

    Returns
    -------
    map_estimate: RaveledVars
        The MAP estimate of the model parameters, raveled into a 1D array.
    inverse_hessian: np.ndarray
        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.

    Raises
    ------
    ValueError
        If ``on_bad_cov`` is not one of 'ignore', 'warn', or 'error'.
    ImportError
        If ``gradient_backend`` is "jax" but JAX is not installed.
    np.linalg.LinAlgError
        If the inverse Hessian is not positive semi-definite and ``on_bad_cov`` is 'error'.
    """
    # Validate up front; the original silently treated unknown values as "ignore".
    if on_bad_cov not in ("ignore", "warn", "error"):
        raise ValueError(
            f"on_bad_cov must be one of 'ignore', 'warn', or 'error'; got {on_bad_cov!r}"
        )
    if gradient_backend == "jax" and not find_spec("jax"):
        raise ImportError("JAX must be installed to use JAX gradients")

    model = pm.modelcontext(model)
    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
    frozen_model = freeze_dims_and_data(model)

    if not transform_samples:
        # Work in the constrained space: strip value transforms and omit the Jacobian.
        untransformed_model = remove_value_transforms(frozen_model)
        logp = untransformed_model.logp(jacobian=False)
        variables = untransformed_model.continuous_value_vars
    else:
        logp = frozen_model.logp(jacobian=True)
        variables = frozen_model.continuous_value_vars

    # Keep only the free parameters relevant to the chosen parameterization.
    variable_names = {var.name for var in variables}
    optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names}
    mu = DictToArrayBijection.map(optimized_free_params)

    # Compile a Hessian function for the negative log-posterior (the loss).
    _, f_hess, _ = scipy_optimize_funcs_from_loss(
        loss=-logp,
        inputs=variables,
        initial_point_dict=optimized_free_params,
        use_grad=True,
        use_hess=True,
        use_hessp=False,
        gradient_backend=gradient_backend,
        compile_kwargs=compile_kwargs,
    )

    # H is the Hessian of the log-posterior; zero out near-zero entries before
    # pseudo-inverting -H (the negative Hessian) for numerical stability.
    H = -f_hess(mu.data)
    H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H))

    def stabilize(x, jitter):
        # Add jitter to the diagonal to nudge the matrix toward positive definiteness.
        return x + np.eye(x.shape[0]) * jitter

    H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter)

    # Cholesky succeeds iff H_inv is positive definite; use it as a PSD check.
    try:
        np.linalg.cholesky(H_inv)
    except np.linalg.LinAlgError:
        if on_bad_cov == "error":
            raise np.linalg.LinAlgError(
                "Inverse Hessian not positive-semi definite at the provided point"
            )
        H_inv = get_nearest_psd(H_inv)
        if on_bad_cov == "warn":
            _log.warning(
                "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD "
                "matrix in L1-norm instead"
            )

    return mu, H_inv
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
def sample_laplace_posterior(
    mu: RaveledVars,
    H_inv: np.ndarray,
    model: pm.Model | None = None,
    chains: int = 2,
    draws: int = 500,
    transform_samples: bool = False,
    progressbar: bool = True,
    random_seed: int | np.random.Generator | None = None,
    compile_kwargs: dict | None = None,
) -> az.InferenceData:
    """
    Generate samples from a multivariate normal distribution with mean `mu` and inverse covariance matrix `H_inv`.

    Parameters
    ----------
    mu: RaveledVars
        The MAP estimate of the model parameters.
    H_inv: np.ndarray
        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
    model : Model
        A PyMC model
    chains : int
        The number of sampling chains running in parallel. Default is 2.
    draws : int
        The number of samples to draw from the approximated posterior. Default is 500.
    transform_samples : bool
        Whether to transform the samples back to the original parameter space. Default is False.
    progressbar : bool
        Whether to display a progress bar during computations. Default is True.
    random_seed: int | np.random.Generator | None
        Seed for the random number generator or a numpy Generator for reproducibility
    compile_kwargs: dict, optional
        Additional keyword arguments to pass to pytensor.function.

    Returns
    -------
    idata: az.InferenceData
        An InferenceData object containing the approximated posterior samples.
    """
    model = pm.modelcontext(model)
    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
    rng = np.random.default_rng(random_seed)

    # allow_singular=True: H_inv may be only positive *semi*-definite after the
    # PSD repair done upstream in fit_mvn_at_MAP.
    posterior_dist = stats.multivariate_normal(
        mean=mu.data, cov=H_inv, allow_singular=True, seed=rng
    )

    posterior_draws = posterior_dist.rvs(size=(chains, draws))
    if mu.data.shape == (1,):
        # With a single parameter rvs collapses the trailing axis; restore it so
        # draws are always (chains, draws, n_params).
        posterior_draws = np.expand_dims(posterior_draws, -1)

    if transform_samples:
        # Draws were made in the unconstrained space; map them back to the
        # constrained space by vectorizing the constraining graph over
        # the (chains, draws) batch dimensions.
        constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)
        batched_values = pt.tensor(
            "batched_values",
            shape=(chains, draws, *unconstrained_vector.type.shape),
            dtype=unconstrained_vector.type.dtype,
        )
        batched_rvs = pytensor.graph.vectorize_graph(
            constrained_rvs, replace={unconstrained_vector: batched_values}
        )

        f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
        posterior_draws = f_constrain(posterior_draws)

    else:
        # Split the flat draw vectors back into one array per model variable,
        # using the (name, shape, size, dtype) metadata from the raveling step.
        info = mu.point_map_info
        flat_shapes = [size for _, _, size, _ in info]
        slices = [
            slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes))
        ]

        posterior_draws = [
            posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype)
            for idx, (name, shape, _, dtype) in zip(slices, info)
        ]

    # Package draws, the fit (mean/covariance), and observed/constant data.
    idata = laplace_draws_to_inferencedata(posterior_draws, model)
    idata = add_fit_to_inferencedata(idata, mu, H_inv)
    idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)

    return idata
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
def fit_laplace(
    optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
    *,
    model: pm.Model | None = None,
    use_grad: bool | None = None,
    use_hessp: bool | None = None,
    use_hess: bool | None = None,
    initvals: dict | None = None,
    random_seed: int | np.random.Generator | None = None,
    return_raw: bool = False,
    jitter_rvs: list[pt.TensorVariable] | None = None,
    progressbar: bool = True,
    include_transformed: bool = True,
    gradient_backend: GradientBackend = "pytensor",
    chains: int = 2,
    draws: int = 500,
    on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
    fit_in_unconstrained_space: bool = False,
    zero_tol: float = 1e-8,
    diag_jitter: float | None = 1e-8,
    optimizer_kwargs: dict | None = None,
    compile_kwargs: dict | None = None,
) -> az.InferenceData:
    """
    Create a Laplace (quadratic) approximation for a posterior distribution.

    This function generates a Laplace approximation for a given posterior distribution using a specified
    number of draws. This is useful for obtaining a parametric approximation to the posterior distribution
    that can be used for further analysis.

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit. If None, the current model context is used.
    optimize_method : str
        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.

        See scipy.optimize.minimize documentation for details.
    use_grad : bool | None, optional
        Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
        the ``optimize_method``.
    use_hessp : bool | None, optional
        Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on
        the ``optimize_method``.
    use_hess : bool | None, optional
        Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on
        the ``optimize_method``.
    initvals : None | dict, optional
        Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
        If None, the model's default initial values are used.
    random_seed : None | int | np.random.Generator, optional
        Seed for the random number generator or a numpy Generator for reproducibility
    return_raw: bool, default False
        Whether to also return the full output of `scipy.optimize.minimize`
    jitter_rvs : list of TensorVariables, optional
        Variables whose initial values should be jittered. If None, all variables are jittered.
    progressbar : bool, optional
        Whether to display a progress bar during optimization. Defaults to True.
    include_transformed : bool, optional
        Whether to include transformed variable values in the MAP output. Defaults to True.
    fit_in_unconstrained_space: bool, default False
        Whether to fit the Laplace approximation in the unconstrained parameter space. If True, samples will be drawn
        from a mean and covariance matrix computed at a point in the **unconstrained** parameter space. Samples will
        then be transformed back to the original parameter space. This will guarantee that the samples will respect
        the domain of prior distributions (for example, samples from a Beta distribution will be strictly between 0
        and 1).

        .. warning::
            This argument should be considered highly experimental. It has not been verified if this method produces
            valid draws from the posterior. **Use at your own risk**.

    gradient_backend: str, default "pytensor"
        The backend to use for gradient computations. Must be one of "pytensor" or "jax".
    chains: int, default: 2
        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
        compatible with the ArviZ library.
    draws: int, default: 500
        The number of samples to draw from the approximated posterior. Total samples will be chains * draws.
    on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
        What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
        If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
        If 'error', an error will be raised.
    zero_tol: float
        Value below which an element of the Hessian matrix is counted as 0.
        This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
    diag_jitter: float | None
        A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
        If None, no jitter is added. Default is 1e-8.
    optimizer_kwargs
        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
        ``optimize_method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
    compile_kwargs: dict, optional
        Additional keyword arguments to pass to pytensor.function.

    Returns
    -------
    :class:`~arviz.InferenceData`
        An InferenceData object containing the approximated posterior samples.

    Examples
    --------
    >>> from pymc_extras.inference.laplace import fit_laplace
    >>> import numpy as np
    >>> import pymc as pm
    >>> import arviz as az
    >>> y = np.array([2642, 3503, 4358]*10)
    >>> with pm.Model() as m:
    >>>     logsigma = pm.Uniform("logsigma", 1, 100)
    >>>     mu = pm.Uniform("mu", -10000, 10000)
    >>>     yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
    >>>     idata = fit_laplace()

    Notes
    -----
    This method of approximation may not be suitable for all types of posterior distributions,
    especially those with significant skewness or multimodality.

    See Also
    --------
    fit : Calling the inference function 'fit' like pmx.fit(method="laplace", model=m)
          will forward the call to 'fit_laplace'.

    """
    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs

    # Step 1: find the MAP point of the (possibly transformed) posterior.
    optimized_point = find_MAP(
        method=optimize_method,
        model=model,
        use_grad=use_grad,
        use_hessp=use_hessp,
        use_hess=use_hess,
        initvals=initvals,
        random_seed=random_seed,
        return_raw=return_raw,
        jitter_rvs=jitter_rvs,
        progressbar=progressbar,
        include_transformed=include_transformed,
        gradient_backend=gradient_backend,
        compile_kwargs=compile_kwargs,
        **optimizer_kwargs,
    )

    # Step 2: fit a multivariate normal at the MAP point via the inverse Hessian.
    mu, H_inv = fit_mvn_at_MAP(
        optimized_point=optimized_point,
        model=model,
        on_bad_cov=on_bad_cov,
        transform_samples=fit_in_unconstrained_space,
        gradient_backend=gradient_backend,
        zero_tol=zero_tol,
        diag_jitter=diag_jitter,
        compile_kwargs=compile_kwargs,
    )

    # Step 3: draw samples from the fitted normal and package them as InferenceData.
    return sample_laplace_posterior(
        mu=mu,
        H_inv=H_inv,
        model=model,
        chains=chains,
        draws=draws,
        transform_samples=fit_in_unconstrained_space,
        progressbar=progressbar,
        random_seed=random_seed,
        compile_kwargs=compile_kwargs,
    )
|