pymc-extras 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/deserialize.py +10 -4
- pymc_extras/distributions/continuous.py +1 -1
- pymc_extras/distributions/histogram_utils.py +6 -4
- pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
- pymc_extras/distributions/timeseries.py +14 -12
- pymc_extras/inference/dadvi/dadvi.py +149 -128
- pymc_extras/inference/laplace_approx/find_map.py +16 -39
- pymc_extras/inference/laplace_approx/idata.py +22 -4
- pymc_extras/inference/laplace_approx/laplace.py +196 -151
- pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
- pymc_extras/inference/pathfinder/idata.py +517 -0
- pymc_extras/inference/pathfinder/pathfinder.py +71 -12
- pymc_extras/inference/smc/sampling.py +2 -2
- pymc_extras/model/marginal/distributions.py +4 -2
- pymc_extras/model/marginal/graph_analysis.py +2 -2
- pymc_extras/model/marginal/marginal_model.py +12 -2
- pymc_extras/model_builder.py +9 -4
- pymc_extras/prior.py +203 -8
- pymc_extras/statespace/core/compile.py +1 -1
- pymc_extras/statespace/core/statespace.py +2 -1
- pymc_extras/statespace/filters/distributions.py +15 -13
- pymc_extras/statespace/filters/kalman_filter.py +24 -22
- pymc_extras/statespace/filters/kalman_smoother.py +3 -5
- pymc_extras/statespace/filters/utilities.py +2 -5
- pymc_extras/statespace/models/DFM.py +12 -27
- pymc_extras/statespace/models/ETS.py +190 -198
- pymc_extras/statespace/models/SARIMAX.py +5 -17
- pymc_extras/statespace/models/VARMAX.py +15 -67
- pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
- pymc_extras/statespace/models/structural/components/regression.py +4 -26
- pymc_extras/statespace/models/utilities.py +7 -0
- pymc_extras/utils/model_equivalence.py +2 -2
- pymc_extras/utils/prior.py +10 -14
- pymc_extras/utils/spline.py +4 -10
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/METADATA +4 -4
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/RECORD +38 -37
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/WHEEL +1 -1
- {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/licenses/LICENSE +0 -0
pymc_extras/inference/laplace_approx/idata.py

@@ -22,10 +22,15 @@ def make_default_labels(name: str, shape: tuple[int, ...]) -> list:
     return [list(range(dim)) for dim in shape]
 
 
-def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]:
+def make_unpacked_variable_names(
+    names: list[str], model: pm.Model, var_name_to_model_var: dict[str, str] | None = None
+) -> list[str]:
     coords = model.coords
     initial_point = model.initial_point()
 
+    if var_name_to_model_var is None:
+        var_name_to_model_var = {}
+
     value_to_dim = {
         value.name: model.named_vars_to_dims.get(model.values_to_rvs[value].name, None)
         for value in model.value_vars
@@ -37,6 +42,7 @@ def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]
 
     unpacked_variable_names = []
     for name in names:
+        name = var_name_to_model_var.get(name, name)
         shape = initial_point[name].shape
         if shape:
             dims = dims_dict.get(name)
@@ -109,7 +115,7 @@ def map_results_to_inference_data(
         x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
     ]
 
-    unconstrained_names = set(all_varnames) - set(constrained_names)
+    unconstrained_names = sorted(set(all_varnames) - set(constrained_names))
 
     idata = az.from_dict(
         posterior={
@@ -258,6 +264,7 @@ def optimizer_result_to_dataset(
     method: minimize_method | Literal["basinhopping"],
     mu: RaveledVars | None = None,
     model: pm.Model | None = None,
+    var_name_to_model_var: dict[str, str] | None = None,
 ) -> xr.Dataset:
     """
     Convert an OptimizeResult object to an xarray Dataset object.
@@ -268,6 +275,9 @@ def optimizer_result_to_dataset(
         The result of the optimization process.
     method: minimize_method or "basinhopping"
         The optimization method used.
+    var_name_to_model_var: dict, optional
+        Mapping between variables in the optimization result and the model variable names. Used when auxiliary
+        variables were introduced, e.g. in DADVI.
 
     Returns
     -------
@@ -279,7 +289,9 @@ def optimizer_result_to_dataset(
 
     model = pm.modelcontext(model) if model is None else model
     variable_names, *_ = zip(*mu.point_map_info)
-    unpacked_variable_names = make_unpacked_variable_names(variable_names, model)
+    unpacked_variable_names = make_unpacked_variable_names(
+        variable_names, model, var_name_to_model_var
+    )
 
     data_vars = {}
 
@@ -368,6 +380,7 @@ def add_optimizer_result_to_inference_data(
     method: minimize_method | Literal["basinhopping"],
     mu: RaveledVars | None = None,
     model: pm.Model | None = None,
+    var_name_to_model_var: dict[str, str] | None = None,
 ) -> az.InferenceData:
     """
     Add the optimization result to an InferenceData object.
@@ -384,13 +397,18 @@ def add_optimizer_result_to_inference_data(
         The MAP estimate of the model parameters.
     model: Model, optional
         A PyMC model. If None, the model is taken from the current model context.
+    var_name_to_model_var: dict, optional
+        Mapping between variables in the optimization result and the model variable names. Used when auxiliary
+        variables were introduced, e.g. in DADVI.
 
     Returns
     -------
     idata: az.InferenceData
         The provided InferenceData, with the optimization results added to the "optimizer" group.
     """
-    dataset = optimizer_result_to_dataset(result, method=method, mu=mu, model=model)
+    dataset = optimizer_result_to_dataset(
+        result, method=method, mu=mu, model=model, var_name_to_model_var=var_name_to_model_var
+    )
     idata.add_groups({"optimizer_result": dataset})
 
     return idata
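The `var_name_to_model_var` hook added above is just a rename table applied before names are unpacked. A minimal sketch of what it does; the auxiliary names below are made up for illustration, not taken from the package:

```python
# Hypothetical illustration: an optimizer such as DADVI may report auxiliary
# names (e.g. "mu_mean"); the mapping routes them back to the model's variable
# names before the unpacked labels are built, exactly like the
# `name = var_name_to_model_var.get(name, name)` line in the hunk above.
var_name_to_model_var = {"mu_mean": "mu", "sigma_log___mean": "sigma_log__"}

optimizer_names = ["mu_mean", "sigma_log___mean"]
model_names = [var_name_to_model_var.get(name, name) for name in optimizer_names]
print(model_names)  # ['mu', 'sigma_log__']
```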
pymc_extras/inference/laplace_approx/laplace.py

@@ -16,9 +16,7 @@
 import logging
 
 from collections.abc import Callable
-from functools import partial
 from typing import Literal
-from typing import cast as type_cast
 
 import arviz as az
 import numpy as np
@@ -27,16 +25,18 @@ import pytensor
 import pytensor.tensor as pt
 import xarray as xr
 
+from arviz import dict_to_dataset
 from better_optimize.constants import minimize_method
 from numpy.typing import ArrayLike
+from pymc import Model
+from pymc.backends.arviz import coords_and_dims_for_inferencedata
 from pymc.blocking import DictToArrayBijection
 from pymc.model.transform.optimization import freeze_dims_and_data
-from pymc.
-from pymc.util import get_default_varnames
+from pymc.util import get_untransformed_name, is_transformed_name
 from pytensor.graph import vectorize_graph
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.optimize import minimize
-from
+from xarray import Dataset
 
 from pymc_extras.inference.laplace_approx.find_map import (
     _compute_inverse_hessian,
@@ -147,130 +147,175 @@ def get_conditional_gaussian_approximation(
     return pytensor.function(args, [x0, conditional_gaussian_approx])
 
 
-def
-
-
-
-]
-    names = [x.name for x in outputs]
+def unpack_last_axis(packed_input, packed_shapes):
+    if len(packed_shapes) == 1:
+        # Single case currently fails in unpack
+        return [pt.split_dims(packed_input, packed_shapes[0], axis=-1)]
 
-
+    keep_axes = tuple(range(packed_input.ndim))[:-1]
+    return pt.unpack(packed_input, axes=keep_axes, packed_shapes=packed_shapes)
 
-    new_outputs, unconstrained_vector = join_nonshared_inputs(
-        model.initial_point(),
-        inputs=model.value_vars,
-        outputs=outputs,
-    )
-
-    constrained_rvs = [x for x, name in zip(new_outputs, names) if name in constrained_names]
-    value_rvs = [x for x in new_outputs if x not in constrained_rvs]
-
-    unconstrained_vector.name = "unconstrained_vector"
 
-
-
+def draws_from_laplace_approx(
+    *,
+    mean,
+    covariance=None,
+    standard_deviation=None,
+    draws: int,
+    model: Model,
+    vectorize_draws: bool = True,
+    return_unconstrained: bool = True,
+    random_seed=None,
+    compile_kwargs: dict | None = None,
+) -> tuple[Dataset, Dataset | None]:
+    """
+    Generate draws from the Laplace approximation of the posterior.
 
-
+    Parameters
+    ----------
+    mean : np.ndarray
+        The mean of the Laplace approximation (MAP estimate).
+    covariance : np.ndarray, optional
+        The covariance matrix of the Laplace approximation.
+        Mutually exclusive with `standard_deviation`.
+    standard_deviation : np.ndarray, optional
+        The standard deviation of the Laplace approximation (diagonal approximation).
+        Mutually exclusive with `covariance`.
+    draws : int
+        The number of draws.
+    model : pm.Model
+        The PyMC model.
+    vectorize_draws : bool, default True
+        Whether to vectorize the draws.
+    return_unconstrained : bool, default True
+        Whether to return the unconstrained draws in addition to the constrained ones.
+    random_seed : int, optional
+        Random seed for reproducibility.
+    compile_kwargs: dict, optional
+        Optional compile kwargs
 
+    Returns
+    -------
+    tuple[Dataset, Dataset | None]
+        A tuple containing the constrained draws (trace) and optionally the unconstrained draws.
+
+    Raises
+    ------
+    ValueError
+        If neither `covariance` nor `standard_deviation` is provided,
+        or if both are provided.
+    """
+    # This function assumes that mean/covariance/standard_deviation are aligned with model.initial_point()
+    if covariance is None and standard_deviation is None:
+        raise ValueError("Must specify either covariance or standard_deviation")
+    if covariance is not None and standard_deviation is not None:
+        raise ValueError("Cannot specify both covariance and standard_deviation")
+    if compile_kwargs is None:
+        compile_kwargs = {}
 
-def model_to_laplace_approx(
-    model: pm.Model, unpacked_variable_names: list[str], chains: int = 1, draws: int = 500
-):
     initial_point = model.initial_point()
-
-
-
-
-
-
-
+    n = int(np.sum([np.prod(v.shape) for v in initial_point.values()]))
+    assert mean.shape == (n,)
+    if covariance is not None:
+        assert covariance.shape == (n, n)
+    elif standard_deviation is not None:
+        assert standard_deviation.shape == (n,)
+
+    vars_to_sample = [v for v in model.free_RVs + model.deterministics]
+    var_names = [v.name for v in vars_to_sample]
+
+    orig_constrained_vars = model.value_vars
+    orig_outputs = model.replace_rvs_by_values(vars_to_sample)
+    if return_unconstrained:
+        orig_outputs.extend(model.value_vars)
+
+    mu_pt = pt.vector("mu", shape=(n,), dtype=mean.dtype)
+    size = (draws,) if vectorize_draws else ()
+    if covariance is not None:
+        sigma_pt = pt.matrix("cov", shape=(n, n), dtype=covariance.dtype)
+        laplace_approximation = pm.MvNormal.dist(mu=mu_pt, cov=sigma_pt, size=size, method="svd")
+    else:
+        sigma_pt = pt.vector("sigma", shape=(n,), dtype=standard_deviation.dtype)
+        laplace_approximation = pm.Normal.dist(mu=mu_pt, sigma=sigma_pt, size=(*size, n))
+
+    constrained_vars = unpack_last_axis(
+        laplace_approximation,
+        [initial_point[v.name].shape for v in orig_constrained_vars],
+    )
+    outputs = vectorize_graph(
+        orig_outputs, replace=dict(zip(orig_constrained_vars, constrained_vars))
     )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        elif name in model.named_vars_to_dims:
-            dims = (*batch_dims, *model.named_vars_to_dims[name])
-        else:
-            dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
-        initval = initial_point.get(name, None)
-        dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
-        laplace_model.add_coords(
-            {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
-        )
-
-        pm.Deterministic(name, batched_rv, dims=dims)
-
-    return laplace_model
-
-
-def unstack_laplace_draws(laplace_data, model, chains=2, draws=500):
-    """
-    The `model_to_laplace_approx` function returns a model with a single MvNormal distribution, draws from which are
-    in the unconstrained variable space. These might be interesting to the user, but since they come back stacked in a
-    single vector, it's not easy to work with.
-
-    This function unpacks each component of the vector into its own DataArray, with the appropriate dimensions and
-    coordinates, where possible.
-    """
-    initial_point = DictToArrayBijection.map(model.initial_point())
-
-    cursor = 0
-    unstacked_laplace_draws = {}
-    coords = model.coords | {"chain": range(chains), "draw": range(draws)}
-
-    # There are corner cases where the value_vars will not have the same dimensions as the random variable (e.g.
-    # simplex transform of a Dirichlet). In these cases, we don't try to guess what the labels should be, and just
-    # add an arviz-style default dim and label.
-    for rv, (name, shape, size, dtype) in zip(model.free_RVs, initial_point.point_map_info):
-        rv_dims = []
-        for i, dim in enumerate(
-            model.named_vars_to_dims.get(rv.name, [f"{name}_dim_{i}" for i in range(len(shape))])
-        ):
-            if coords.get(dim) and shape[i] == len(coords[dim]):
-                rv_dims.append(dim)
-            else:
-                rv_dims.append(f"{name}_dim_{i}")
-                coords[f"{name}_dim_{i}"] = np.arange(shape[i])
-
-        dims = ("chain", "draw", *rv_dims)
-
-        values = (
-            laplace_data[..., cursor : cursor + size].reshape((chains, draws, *shape)).astype(dtype)
+    fn = pm.pytensorf.compile(
+        [mu_pt, sigma_pt],
+        outputs,
+        random_seed=random_seed,
+        trust_input=True,
+        **compile_kwargs,
+    )
+    sigma = covariance if covariance is not None else standard_deviation
+    if vectorize_draws:
+        output_buffers = fn(mean, sigma)
+    else:
+        # Take one draw to find the shape of the outputs
+        output_buffers = []
+        for out_draw in fn(mean, sigma):
+            output_buffer = np.empty((draws, *out_draw.shape), dtype=out_draw.dtype)
+            output_buffer[0] = out_draw
+            output_buffers.append(output_buffer)
+        # Fill one draws at a time
+        for i in range(1, draws):
+            for out_buffer, out_draw in zip(output_buffers, fn(mean, sigma)):
+                out_buffer[i] = out_draw
+
+    model_coords, model_dims = coords_and_dims_for_inferencedata(model)
+    posterior = {
+        var_name: out_buffer[None]
+        for var_name, out_buffer in (
+            zip(var_names, output_buffers, strict=not return_unconstrained)
         )
-
-
+    }
+    posterior_dataset = dict_to_dataset(posterior, coords=model_coords, dims=model_dims, library=pm)
+    unconstrained_posterior_dataset = None
+
+    if return_unconstrained:
+        unconstrained_posterior = {
+            var.name: out_buffer[None]
+            for var, out_buffer in zip(
+                model.value_vars, output_buffers[len(posterior) :], strict=True
+            )
+        }
+        # Attempt to map constrained dims to unconstrained dims
+        for var_name, var_draws in unconstrained_posterior.items():
+            if not is_transformed_name(var_name):
+                # constrained == unconstrained, dims already shared
+                continue
+            constrained_dims = model_dims.get(get_untransformed_name(var_name))
+            if constrained_dims is None or (len(constrained_dims) != (var_draws.ndim - 2)):
+                continue
+            # Reuse dims from constrained variable if they match in length with unconstrained draws
+            inferred_dims = []
+            for i, (constrained_dim, unconstrained_dim_length) in enumerate(
+                zip(constrained_dims, var_draws.shape[2:], strict=True)
+            ):
+                if model_coords.get(constrained_dim) is not None and (
+                    len(model_coords[constrained_dim]) == unconstrained_dim_length
+                ):
+                    # Assume coordinates map. This could be fooled, by e.g., having a transform that reverses values
+                    inferred_dims.append(constrained_dim)
+                else:
+                    # Size mismatch (e.g., Simplex), make no assumption about mapping
+                    inferred_dims.append(f"{var_name}_dim_{i}")
+            model_dims[var_name] = inferred_dims
+
+        unconstrained_posterior_dataset = dict_to_dataset(
+            unconstrained_posterior,
+            coords=model_coords,
+            dims=model_dims,
+            library=pm,
         )
 
-
-
-    unstacked_laplace_draws = xr.Dataset(unstacked_laplace_draws)
-
-    return unstacked_laplace_draws
+    return posterior_dataset, unconstrained_posterior_dataset
 
 
 def fit_laplace(
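A usage sketch of the new `draws_from_laplace_approx` helper, based only on the signature and behavior shown in the hunk above. The import path, the toy model, and the mean/covariance values are illustrative assumptions, not part of the diff:

```python
import numpy as np
import pymc as pm

# Assumed import path for the helper defined in the hunk above.
from pymc_extras.inference.laplace_approx.laplace import draws_from_laplace_approx

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.HalfNormal("sigma", 1)

# mean/covariance must line up with model.initial_point(), here (mu, sigma_log__), so n = 2.
mean = np.array([0.1, -0.5])
cov = np.diag([0.05, 0.02])

posterior, unconstrained = draws_from_laplace_approx(
    mean=mean,
    covariance=cov,
    draws=1000,
    model=model,
    random_seed=1,
)
# Each returned dataset carries a singleton chain dimension, e.g. (1, 1000) for the scalar "mu".
print(posterior["mu"].shape)
```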
@@ -285,9 +330,11 @@ def fit_laplace(
     jitter_rvs: list[pt.TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
+    freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
-    chains: int = 2,
+    chains: None | int = None,
     draws: int = 500,
+    vectorize_draws: bool = True,
     optimizer_kwargs: dict | None = None,
     compile_kwargs: dict | None = None,
 ) -> az.InferenceData:
@@ -328,18 +375,20 @@ def fit_laplace(
     include_transformed: bool, default True
         Whether to include transformed variables in the output. If True, transformed variables will be included in the
         output InferenceData object. If False, only the original variables will be included.
+    freeze_model: bool, optional
+        If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
+        sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
+        True.
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
-    chains: int, default: 2
-        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
-        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
-        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior.
+        The number of samples to draw from the approximated posterior.
     optimizer_kwargs
         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
         ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
+    vectorize_draws: bool, default True
+        Whether to natively vectorize the random function or take one at a time in a python loop.
     compile_kwargs: dict, optional
         Additional keyword arguments to pass to pytensor.function.
 
@@ -354,7 +403,7 @@ def fit_laplace(
     >>> import numpy as np
     >>> import pymc as pm
    >>> import arviz as az
-    >>> y = np.array([2642, 3503, 4358]*10)
+    >>> y = np.array([2642, 3503, 4358] * 10)
     >>> with pm.Model() as m:
     >>>     logsigma = pm.Uniform("logsigma", 1, 100)
     >>>     mu = pm.Uniform("mu", -10000, 10000)
@@ -372,10 +421,19 @@ def fit_laplace(
         will forward the call to 'fit_laplace'.
 
     """
+    if chains is not None:
+        raise ValueError(
+            "chains argument has been deprecated. "
+            "The behavior can be recreated by unstacking draws into multiple chains after fitting"
+        )
+
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
     model = pm.modelcontext(model) if model is None else model
 
+    if freeze_model:
+        model = freeze_dims_and_data(model)
+
     idata = find_MAP(
         method=optimize_method,
         model=model,
@@ -387,17 +445,17 @@ def fit_laplace(
         jitter_rvs=jitter_rvs,
         progressbar=progressbar,
         include_transformed=include_transformed,
+        freeze_model=False,
         gradient_backend=gradient_backend,
         compile_kwargs=compile_kwargs,
         compute_hessian=True,
         **optimizer_kwargs,
     )
 
-    unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
-
     if "covariance_matrix" not in idata.fit:
         # The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so
         # we have to go back and compute the Hessian at the MAP point now.
+        unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
         frozen_model = freeze_dims_and_data(model)
         initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
 
@@ -426,29 +484,16 @@ def fit_laplace(
         coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
     )
 
-
-
-
-
-
-
-
-
-
-
-
-
-        .drop_vars(["chain", "draw"])
-        .rename({"temp_chain": "chain", "temp_draw": "draw"})
-    )
-
-    if include_transformed:
-        idata.unconstrained_posterior = unstack_laplace_draws(
-            new_posterior.laplace_approximation.values, model, chains=chains, draws=draws
-        )
-
-    idata.posterior = new_posterior.drop_vars(
-        ["laplace_approximation", "unpacked_variable_names"]
-    )
-
+    # We override the posterior/unconstrained_posterior from find_MAP
+    idata.posterior, unconstrained_posterior = draws_from_laplace_approx(
+        mean=idata.fit["mean_vector"].values,
+        covariance=idata.fit["covariance_matrix"].values,
+        draws=draws,
+        return_unconstrained=include_transformed,
+        model=model,
+        vectorize_draws=vectorize_draws,
+        random_seed=random_seed,
+    )
+    if include_transformed:
+        idata.unconstrained_posterior = unconstrained_posterior
     return idata
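With the `chains` argument removed, a `fit_laplace` call after this change looks roughly like the docstring example above; as the new error message suggests, a multi-chain layout can be recreated by reshaping the draws afterwards. This is a hedged sketch (the top-level `fit_laplace` re-export and the reshape step are assumptions), not code from the package:

```python
import numpy as np
import pymc as pm
import xarray as xr

from pymc_extras import fit_laplace  # assumed top-level re-export

y = np.array([2642, 3503, 4358] * 10)
with pm.Model() as m:
    logsigma = pm.Uniform("logsigma", 1, 100)
    mu = pm.Uniform("mu", -10000, 10000)
    pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
    idata = fit_laplace(draws=1000, vectorize_draws=True)

# Optional: recreate a 2-chain layout by reshaping the single chain of 1000 draws.
split = {
    name: (("chain", "draw", *da.dims[2:]), da.values.reshape(2, 500, *da.shape[2:]))
    for name, da in idata.posterior.items()
}
posterior_2chains = xr.Dataset(split)
```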
pymc_extras/inference/laplace_approx/scipy_interface.py

@@ -1,3 +1,5 @@
+import logging
+
 from collections.abc import Callable
 from importlib.util import find_spec
 from typing import Literal, get_args
@@ -6,6 +8,7 @@ import numpy as np
 import pymc as pm
 import pytensor
 
+from better_optimize.constants import MINIMIZE_MODE_KWARGS
 from pymc import join_nonshared_inputs
 from pytensor import tensor as pt
 from pytensor.compile import Function
@@ -14,6 +17,39 @@ from pytensor.tensor import TensorVariable
 GradientBackend = Literal["pytensor", "jax"]
 VALID_BACKENDS = get_args(GradientBackend)
 
+_log = logging.getLogger(__name__)
+
+
+def set_optimizer_function_defaults(
+    method: str, use_grad: bool | None, use_hess: bool | None, use_hessp: bool | None
+):
+    method_info = MINIMIZE_MODE_KWARGS[method].copy()
+
+    if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
+        use_hess = False
+
+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
+    return use_grad, use_hess, use_hessp
+
 
 def _compile_grad_and_hess_to_jax(
     f_fused: Function, use_hess: bool, use_hessp: bool
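A small sketch of how the new `set_optimizer_function_defaults` helper resolves unspecified flags, following the logic in the hunk above. The import path and the choice of `"trust-ncg"` as an example method are assumptions:

```python
# Assumed import path for the helper added in the hunk above.
from pymc_extras.inference.laplace_approx.scipy_interface import (
    set_optimizer_function_defaults,
)

# Nothing specified: fall back to what the method supports (per better_optimize's
# MINIMIZE_MODE_KWARGS table), preferring hessp when both Hessian forms are usable.
print(set_optimizer_function_defaults("trust-ncg", None, None, None))

# Only use_hessp given: use_hess is filled in as its complement.
print(set_optimizer_function_defaults("trust-ncg", None, None, True))

# Both Hessian flags requested: a warning is emitted and use_hess is set back to False.
print(set_optimizer_function_defaults("trust-ncg", True, True, True))
```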
@@ -144,12 +180,13 @@ def _compile_functions_for_scipy_optimize(
 def scipy_optimize_funcs_from_loss(
     loss: TensorVariable,
     inputs: list[TensorVariable],
-    initial_point_dict: dict[str, np.ndarray | float | int],
-    use_grad: bool,
-    use_hess: bool,
-    use_hessp: bool,
+    initial_point_dict: dict[str, np.ndarray | float | int] | None = None,
+    use_grad: bool | None = None,
+    use_hess: bool | None = None,
+    use_hessp: bool | None = None,
     gradient_backend: GradientBackend = "pytensor",
     compile_kwargs: dict | None = None,
+    inputs_are_flat: bool = False,
 ) -> tuple[Callable, ...]:
     """
     Compile loss functions for use with scipy.optimize.minimize.
@@ -206,9 +243,12 @@
     if not isinstance(inputs, list):
         inputs = [inputs]
 
-
-
-
+    if inputs_are_flat:
+        [flat_input] = inputs
+    else:
+        [loss], flat_input = join_nonshared_inputs(
+            point=initial_point_dict, outputs=[loss], inputs=inputs
+        )
 
     # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
     # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them