pymc-extras 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/distributions/timeseries.py +10 -10
- pymc_extras/inference/dadvi/dadvi.py +14 -83
- pymc_extras/inference/laplace_approx/laplace.py +187 -159
- pymc_extras/inference/pathfinder/pathfinder.py +12 -7
- pymc_extras/inference/smc/sampling.py +2 -2
- pymc_extras/model/marginal/distributions.py +4 -2
- pymc_extras/model/marginal/marginal_model.py +12 -2
- pymc_extras/prior.py +3 -3
- pymc_extras/statespace/core/properties.py +276 -0
- pymc_extras/statespace/core/statespace.py +182 -45
- pymc_extras/statespace/filters/distributions.py +19 -34
- pymc_extras/statespace/filters/kalman_filter.py +13 -12
- pymc_extras/statespace/filters/kalman_smoother.py +2 -2
- pymc_extras/statespace/models/DFM.py +179 -168
- pymc_extras/statespace/models/ETS.py +177 -151
- pymc_extras/statespace/models/SARIMAX.py +149 -152
- pymc_extras/statespace/models/VARMAX.py +134 -145
- pymc_extras/statespace/models/__init__.py +8 -1
- pymc_extras/statespace/models/structural/__init__.py +30 -8
- pymc_extras/statespace/models/structural/components/autoregressive.py +87 -45
- pymc_extras/statespace/models/structural/components/cycle.py +119 -80
- pymc_extras/statespace/models/structural/components/level_trend.py +95 -42
- pymc_extras/statespace/models/structural/components/measurement_error.py +27 -17
- pymc_extras/statespace/models/structural/components/regression.py +105 -68
- pymc_extras/statespace/models/structural/components/seasonality.py +138 -100
- pymc_extras/statespace/models/structural/core.py +397 -286
- pymc_extras/statespace/models/utilities.py +5 -20
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/METADATA +4 -4
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/RECORD +31 -30
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/WHEEL +0 -0
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/licenses/LICENSE +0 -0

pymc_extras/distributions/timeseries.py:

@@ -196,21 +196,20 @@ class DiscreteMarkovChain(Distribution):
         state_rng = pytensor.shared(np.random.default_rng())
 
         def transition(*args):
-            *states, transition_probs, old_rng = args
+            old_rng, *states, transition_probs = args
             p = transition_probs[tuple(states)]
             next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
-            return next_state, {old_rng: next_rng}
+            return next_rng, next_state
 
-        markov_chain, state_updates = pytensor.scan(
+        state_next_rng, markov_chain = pytensor.scan(
             transition,
-            non_sequences=[P_, state_rng],
-            outputs_info=_make_outputs_info(n_lags, init_dist_),
+            outputs_info=[state_rng, *_make_outputs_info(n_lags, init_dist_)],
+            non_sequences=[P_],
             n_steps=steps_,
             strict=True,
+            return_updates=False,
         )
 
-        (state_next_rng,) = tuple(state_updates.values())
-
         discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)
 
         discrete_mc_op = DiscreteMarkovChainRV(
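
The rewrite above switches `pytensor.scan` from implicit shared-RNG updates to an RNG threaded explicitly through `outputs_info`, with `return_updates=False` suppressing the updates dictionary. A minimal sketch of the same pattern on a toy random walk (the step function and names are illustrative, not from the package, and it assumes a pytensor version recent enough to support `return_updates`):

import numpy as np
import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng(0))

def step(old_rng, x):
    # The RNG is an explicit recurrent state: each step consumes the old
    # generator and returns the next one alongside the new value.
    next_rng, draw = pt.random.normal(0, 1, rng=old_rng).owner.outputs
    return next_rng, x + draw

next_rng, walk = pytensor.scan(
    step,
    outputs_info=[rng, pt.zeros(())],
    n_steps=10,
    return_updates=False,  # scan returns outputs only, no updates dict
)
print(pytensor.function([], walk)())  # 10 cumulative draws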

@@ -239,16 +238,17 @@ def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng):
     n_lags = op.n_lags
 
     def greedy_transition(*args):
-        *states, transition_probs, old_rng = args
+        *states, transition_probs = args
         p = transition_probs[tuple(states)]
         return pt.argmax(p)
 
-    chain_moment, updates = pytensor.scan(
+    chain_moment = pytensor.scan(
         greedy_transition,
-        non_sequences=[P, state_rng],
+        non_sequences=[P],
         outputs_info=_make_outputs_info(n_lags, init_dist),
         n_steps=steps,
         strict=True,
+        return_updates=False,
     )
     chain_moment = pt.concatenate([init_dist_moment, chain_moment])
     return chain_moment
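
The moment rewrite keeps the same structure but drops the now-unused RNG: each scan step deterministically follows the most probable transition. In plain numpy the idea looks like this (toy transition matrix, illustrative only):

import numpy as np

P = np.array([[0.9, 0.1], [0.2, 0.8]])  # toy 2-state transition matrix
state = 0  # mode of the initial distribution
path = [state]
for _ in range(5):
    state = int(np.argmax(P[state]))  # greedy: most probable next state
    path.append(state)
print(path)  # [0, 0, 0, 0, 0, 0] — the greedy chain stays in the sticky state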

pymc_extras/inference/dadvi/dadvi.py:

@@ -3,25 +3,20 @@ import numpy as np
 import pymc
 import pytensor
 import pytensor.tensor as pt
-import xarray
 
+from arviz import InferenceData
 from better_optimize import basinhopping, minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
-from pymc.backends.arviz import (
-    PointFunc,
-    apply_function_over_dataset,
-    coords_and_dims_for_inferencedata,
-)
 from pymc.blocking import RaveledVars
-from pymc.util import RandomSeed, get_default_varnames
+from pymc.util import RandomSeed
 from pytensor.tensor.variable import TensorVariable
 
 from pymc_extras.inference.laplace_approx.idata import (
     add_data_to_inference_data,
     add_optimizer_result_to_inference_data,
 )
-from pymc_extras.inference.laplace_approx.laplace import …
+from pymc_extras.inference.laplace_approx.laplace import draws_from_laplace_approx
 from pymc_extras.inference.laplace_approx.scipy_interface import (
     scipy_optimize_funcs_from_loss,
     set_optimizer_function_defaults,

@@ -193,16 +188,18 @@ def fit_dadvi(
     opt_var_params = result.x
     opt_means, opt_log_sds = np.split(opt_var_params, 2)
 
-…
-        draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+    posterior, unconstrained_posterior = draws_from_laplace_approx(
+        mean=opt_means,
+        standard_deviation=np.exp(opt_log_sds),
+        draws=n_draws,
+        model=model,
+        vectorize_draws=False,
+        return_unconstrained=include_transformed,
+        random_seed=random_seed,
     )
+    idata = InferenceData(posterior=posterior)
+    if include_transformed:
+        idata.add_groups(unconstrained_posterior=unconstrained_posterior)
 
     var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
     var_name_to_model_var.update(
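
For context on the new call: DADVI optimizes one flat vector holding the variational means followed by the log standard deviations, which the `np.split` and `np.exp` lines above unpack into the arguments `draws_from_laplace_approx` expects. A toy illustration:

import numpy as np

# Toy optimum for a model with two unconstrained parameters:
# first half = means, second half = log standard deviations.
opt_var_params = np.array([0.5, -1.2, 0.0, -0.7])
opt_means, opt_log_sds = np.split(opt_var_params, 2)
print(opt_means)            # [ 0.5 -1.2]
print(np.exp(opt_log_sds))  # standard deviations, always positive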

@@ -283,69 +280,3 @@ def create_dadvi_graph(
     objective = -mean_log_density - entropy
 
     return var_params, objective
-
-
-def dadvi_result_to_idata(
-    unstacked_draws: xarray.Dataset,
-    model: Model,
-    include_transformed: bool = False,
-    progressbar: bool = True,
-):
-    """
-    Transforms the unconstrained draws back into the constrained space.
-
-    Parameters
-    ----------
-    unstacked_draws : xarray.Dataset
-        The draws to constrain back into the original space.
-
-    model : Model
-        The PyMC model the variables were derived from.
-
-    n_draws: int
-        The number of draws to return from the variational approximation.
-
-    include_transformed: bool
-        Whether or not to keep the unconstrained variables in the output.
-
-    progressbar: bool
-        Whether or not to show a progress bar during the transformation. Default is True.
-
-    Returns
-    -------
-    :class:`~arviz.InferenceData`
-        Draws from the original constrained parameters.
-    """
-
-    filtered_var_names = model.unobserved_value_vars
-    vars_to_sample = list(
-        get_default_varnames(filtered_var_names, include_transformed=include_transformed)
-    )
-    fn = pytensor.function(model.value_vars, vars_to_sample)
-    point_func = PointFunc(fn)
-
-    coords, dims = coords_and_dims_for_inferencedata(model)
-
-    transformed_result = apply_function_over_dataset(
-        point_func,
-        unstacked_draws,
-        output_var_names=[x.name for x in vars_to_sample],
-        coords=coords,
-        dims=dims,
-        progressbar=progressbar,
-    )
-
-    constrained_names = [
-        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
-    ]
-    all_varnames = [
-        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
-    ]
-    unconstrained_names = sorted(set(all_varnames) - set(constrained_names))
-
-    idata = az.InferenceData(posterior=transformed_result[constrained_names])
-
-    if unconstrained_names and include_transformed:
-        idata["unconstrained_posterior"] = transformed_result[unconstrained_names]
-
-    return idata

pymc_extras/inference/laplace_approx/laplace.py:

@@ -16,9 +16,7 @@
 import logging
 
 from collections.abc import Callable
-from functools import partial
 from typing import Literal
-from typing import cast as type_cast
 
 import arviz as az
 import numpy as np

@@ -27,16 +25,18 @@ import pytensor
 import pytensor.tensor as pt
 import xarray as xr
 
+from arviz import dict_to_dataset
 from better_optimize.constants import minimize_method
 from numpy.typing import ArrayLike
+from pymc import Model
+from pymc.backends.arviz import coords_and_dims_for_inferencedata
 from pymc.blocking import DictToArrayBijection
 from pymc.model.transform.optimization import freeze_dims_and_data
-from pymc.…
-from pymc.util import get_default_varnames
+from pymc.util import get_untransformed_name, is_transformed_name
 from pytensor.graph import vectorize_graph
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.optimize import minimize
-from …
+from xarray import Dataset
 
 from pymc_extras.inference.laplace_approx.find_map import (
     _compute_inverse_hessian,

@@ -137,7 +137,7 @@ def get_conditional_gaussian_approximation(
     hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
 
     # Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
-    _, logdetQ = pt.nlinalg.slogdet(Q)
+    _, logdetQ = pt.linalg.slogdet(Q)
     conditional_gaussian_approx = (
         -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
     )
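
`pt.linalg.slogdet` returns the sign and the logarithm of the absolute determinant separately, which keeps the `0.5 * logdetQ` term stable for large or badly scaled precision matrices. A quick sketch:

import numpy as np
import pytensor
import pytensor.tensor as pt

Q = pt.matrix("Q")
sign, logdetQ = pt.linalg.slogdet(Q)
f = pytensor.function([Q], [sign, logdetQ])
print(f(2.0 * np.eye(3)))  # sign 1.0, log-det = 3 * log(2)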

@@ -147,138 +147,175 @@ def get_conditional_gaussian_approximation(
     return pytensor.function(args, [x0, conditional_gaussian_approx])
 
 
-def …
-…
-    ]
-    names = [x.name for x in outputs]
+def unpack_last_axis(packed_input, packed_shapes):
+    if len(packed_shapes) == 1:
+        # Single case currently fails in unpack
+        return [pt.split_dims(packed_input, packed_shapes[0], axis=-1)]
 
-…
+    keep_axes = tuple(range(packed_input.ndim))[:-1]
+    return pt.unpack(packed_input, keep_axes=keep_axes, packed_shapes=packed_shapes)
 
-    new_outputs, unconstrained_vector = join_nonshared_inputs(
-        model.initial_point(),
-        inputs=model.value_vars,
-        outputs=outputs,
-    )
-
-    constrained_rvs = [x for x, name in zip(new_outputs, names) if name in constrained_names]
-    value_rvs = [x for x in new_outputs if x not in constrained_rvs]
-
-    unconstrained_vector.name = "unconstrained_vector"
 
-…
+def draws_from_laplace_approx(
+    *,
+    mean,
+    covariance=None,
+    standard_deviation=None,
+    draws: int,
+    model: Model,
+    vectorize_draws: bool = True,
+    return_unconstrained: bool = True,
+    random_seed=None,
+    compile_kwargs: dict | None = None,
+) -> tuple[Dataset, Dataset | None]:
+    """
+    Generate draws from the Laplace approximation of the posterior.
 
-…
+    Parameters
+    ----------
+    mean : np.ndarray
+        The mean of the Laplace approximation (MAP estimate).
+    covariance : np.ndarray, optional
+        The covariance matrix of the Laplace approximation.
+        Mutually exclusive with `standard_deviation`.
+    standard_deviation : np.ndarray, optional
+        The standard deviation of the Laplace approximation (diagonal approximation).
+        Mutually exclusive with `covariance`.
+    draws : int
+        The number of draws.
+    model : pm.Model
+        The PyMC model.
+    vectorize_draws : bool, default True
+        Whether to vectorize the draws.
+    return_unconstrained : bool, default True
+        Whether to return the unconstrained draws in addition to the constrained ones.
+    random_seed : int, optional
+        Random seed for reproducibility.
+    compile_kwargs: dict, optional
+        Optional compile kwargs
 
+    Returns
+    -------
+    tuple[Dataset, Dataset | None]
+        A tuple containing the constrained draws (trace) and optionally the unconstrained draws.
+
+    Raises
+    ------
+    ValueError
+        If neither `covariance` nor `standard_deviation` is provided,
+        or if both are provided.
+    """
+    # This function assumes that mean/covariance/standard_deviation are aligned with model.initial_point()
+    if covariance is None and standard_deviation is None:
+        raise ValueError("Must specify either covariance or standard_deviation")
+    if covariance is not None and standard_deviation is not None:
+        raise ValueError("Cannot specify both covariance and standard_deviation")
+    if compile_kwargs is None:
+        compile_kwargs = {}
 
-def model_to_laplace_approx(
-    model: pm.Model, unpacked_variable_names: list[str], chains: int = 1, draws: int = 500
-):
     initial_point = model.initial_point()
-…
+    n = int(np.sum([np.prod(v.shape) for v in initial_point.values()]))
+    assert mean.shape == (n,)
+    if covariance is not None:
+        assert covariance.shape == (n, n)
+    elif standard_deviation is not None:
+        assert standard_deviation.shape == (n,)
+
+    vars_to_sample = [v for v in model.free_RVs + model.deterministics]
+    var_names = [v.name for v in vars_to_sample]
+
+    orig_constrained_vars = model.value_vars
+    orig_outputs = model.replace_rvs_by_values(vars_to_sample)
+    if return_unconstrained:
+        orig_outputs.extend(model.value_vars)
+
+    mu_pt = pt.vector("mu", shape=(n,), dtype=mean.dtype)
+    size = (draws,) if vectorize_draws else ()
+    if covariance is not None:
+        sigma_pt = pt.matrix("cov", shape=(n, n), dtype=covariance.dtype)
+        laplace_approximation = pm.MvNormal.dist(mu=mu_pt, cov=sigma_pt, size=size, method="svd")
+    else:
+        sigma_pt = pt.vector("sigma", shape=(n,), dtype=standard_deviation.dtype)
+        laplace_approximation = pm.Normal.dist(mu=mu_pt, sigma=sigma_pt, size=(*size, n))
+
+    constrained_vars = unpack_last_axis(
+        laplace_approximation,
+        [initial_point[v.name].shape for v in orig_constrained_vars],
+    )
+    outputs = vectorize_graph(
+        orig_outputs, replace=dict(zip(orig_constrained_vars, constrained_vars))
     )
 
-…
-            dims = batch_dims
-        elif name in model.named_vars_to_dims:
-            dims = (*batch_dims, *model.named_vars_to_dims[name])
-        else:
-            dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
-        initval = initial_point.get(name, None)
-        dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
-        laplace_model.add_coords(
-            {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
-        )
-
-        pm.Deterministic(name, batched_rv, dims=dims)
-
-    return laplace_model
-
-
-def unstack_laplace_draws(laplace_data, model, chains=2, draws=500):
-    """
-    The `model_to_laplace_approx` function returns a model with a single MvNormal distribution, draws from which are
-    in the unconstrained variable space. These might be interesting to the user, but since they come back stacked in a
-    single vector, it's not easy to work with.
-
-    This function unpacks each component of the vector into its own DataArray, with the appropriate dimensions and
-    coordinates, where possible.
-    """
-    initial_point = DictToArrayBijection.map(model.initial_point())
-
-    cursor = 0
-    unstacked_laplace_draws = {}
-    coords = model.coords | {"chain": range(chains), "draw": range(draws)}
-
-    # There are corner cases where the value_vars will not have the same dimensions as the random variable (e.g.
-    # simplex transform of a Dirichlet). In these cases, we don't try to guess what the labels should be, and just
-    # add an arviz-style default dim and label.
-    for rv, (name, shape, size, dtype) in zip(model.free_RVs, initial_point.point_map_info):
-        rv_dims = []
-        for i, dim in enumerate(
-            model.named_vars_to_dims.get(rv.name, [f"{name}_dim_{i}" for i in range(len(shape))])
-        ):
-            if coords.get(dim) and shape[i] == len(coords[dim]):
-                rv_dims.append(dim)
-            else:
-                rv_dims.append(f"{name}_dim_{i}")
-                coords[f"{name}_dim_{i}"] = np.arange(shape[i])
-
-        dims = ("chain", "draw", *rv_dims)
-
-        values = (
-            laplace_data[..., cursor : cursor + size].reshape((chains, draws, *shape)).astype(dtype)
+    fn = pm.pytensorf.compile(
+        [mu_pt, sigma_pt],
+        outputs,
+        random_seed=random_seed,
+        trust_input=True,
+        **compile_kwargs,
+    )
+    sigma = covariance if covariance is not None else standard_deviation
+    if vectorize_draws:
+        output_buffers = fn(mean, sigma)
+    else:
+        # Take one draw to find the shape of the outputs
+        output_buffers = []
+        for out_draw in fn(mean, sigma):
+            output_buffer = np.empty((draws, *out_draw.shape), dtype=out_draw.dtype)
+            output_buffer[0] = out_draw
+            output_buffers.append(output_buffer)
+        # Fill one draw at a time
+        for i in range(1, draws):
+            for out_buffer, out_draw in zip(output_buffers, fn(mean, sigma)):
+                out_buffer[i] = out_draw
+
+    model_coords, model_dims = coords_and_dims_for_inferencedata(model)
+    posterior = {
+        var_name: out_buffer[None]
+        for var_name, out_buffer in (
+            zip(var_names, output_buffers, strict=not return_unconstrained)
         )
-…
+    }
+    posterior_dataset = dict_to_dataset(posterior, coords=model_coords, dims=model_dims, library=pm)
+    unconstrained_posterior_dataset = None
+
+    if return_unconstrained:
+        unconstrained_posterior = {
+            var.name: out_buffer[None]
+            for var, out_buffer in zip(
+                model.value_vars, output_buffers[len(posterior) :], strict=True
+            )
+        }
+        # Attempt to map constrained dims to unconstrained dims
+        for var_name, var_draws in unconstrained_posterior.items():
+            if not is_transformed_name(var_name):
+                # constrained == unconstrained, dims already shared
+                continue
+            constrained_dims = model_dims.get(get_untransformed_name(var_name))
+            if constrained_dims is None or (len(constrained_dims) != (var_draws.ndim - 2)):
+                continue
+            # Reuse dims from constrained variable if they match in length with unconstrained draws
+            inferred_dims = []
+            for i, (constrained_dim, unconstrained_dim_length) in enumerate(
+                zip(constrained_dims, var_draws.shape[2:], strict=True)
+            ):
+                if model_coords.get(constrained_dim) is not None and (
+                    len(model_coords[constrained_dim]) == unconstrained_dim_length
+                ):
+                    # Assume coordinates map. This could be fooled, by e.g., having a transform that reverses values
+                    inferred_dims.append(constrained_dim)
+                else:
+                    # Size mismatch (e.g., Simplex), make no assumption about mapping
+                    inferred_dims.append(f"{var_name}_dim_{i}")
+            model_dims[var_name] = inferred_dims
+
+        unconstrained_posterior_dataset = dict_to_dataset(
+            unconstrained_posterior,
+            coords=model_coords,
+            dims=model_dims,
+            library=pm,
         )
 
-…
-    unstacked_laplace_draws = xr.Dataset(unstacked_laplace_draws)
-
-    return unstacked_laplace_draws
+    return posterior_dataset, unconstrained_posterior_dataset
 
 
 def fit_laplace(
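
A usage sketch for the new helper on a toy model (illustrative only; as the comment in the body notes, `mean` and `standard_deviation` must be aligned with the model's flattened `initial_point`):

import numpy as np
import pymc as pm

from pymc_extras.inference.laplace_approx.laplace import draws_from_laplace_approx

with pm.Model() as model:
    x = pm.Normal("x")  # one free RV -> unconstrained vector of length 1

posterior, unconstrained = draws_from_laplace_approx(
    mean=np.zeros(1),
    standard_deviation=np.ones(1),
    draws=100,
    model=model,
    random_seed=0,
)
print(posterior["x"].shape)  # (1, 100): one chain dimension, 100 draws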

@@ -295,8 +332,9 @@ def fit_laplace(
     include_transformed: bool = True,
     freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
-    chains: int = 2,
+    chains: None | int = None,
     draws: int = 500,
+    vectorize_draws: bool = True,
     optimizer_kwargs: dict | None = None,
     compile_kwargs: dict | None = None,
 ) -> az.InferenceData:

@@ -343,16 +381,14 @@ def fit_laplace(
         True.
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
-    chains: int, default: 2
-        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
-        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
-        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior.
+        The number of samples to draw from the approximated posterior.
     optimizer_kwargs
         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
         ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
+    vectorize_draws: bool, default True
+        Whether to natively vectorize the random function or take one at a time in a python loop.
     compile_kwargs: dict, optional
         Additional keyword arguments to pass to pytensor.function.

@@ -385,6 +421,12 @@ def fit_laplace(
         will forward the call to 'fit_laplace'.
 
     """
+    if chains is not None:
+        raise ValueError(
+            "chains argument has been deprecated. "
+            "The behavior can be recreated by unstacking draws into multiple chains after fitting"
+        )
+
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
     model = pm.modelcontext(model) if model is None else model
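
A sketch of the workaround named in the error message: because the Laplace draws are i.i.d., a chain dimension can be recreated by reshaping after fitting (the helper name here is hypothetical):

import numpy as np

def split_into_chains(samples: np.ndarray, n_chains: int) -> np.ndarray:
    """Reshape (1, draws, *shape) samples into (n_chains, draws // n_chains, *shape)."""
    _, n_draws, *event_shape = samples.shape
    assert n_draws % n_chains == 0, "draws must divide evenly across chains"
    return samples.reshape(n_chains, n_draws // n_chains, *event_shape)

print(split_into_chains(np.zeros((1, 500, 3)), n_chains=4).shape)  # (4, 125, 3)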

@@ -410,11 +452,10 @@ def fit_laplace(
         **optimizer_kwargs,
     )
 
-    unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
-
     if "covariance_matrix" not in idata.fit:
         # The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so
         # we have to go back and compute the Hessian at the MAP point now.
+        unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
         frozen_model = freeze_dims_and_data(model)
         initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
 

@@ -443,29 +484,16 @@ def fit_laplace(
         coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
     )
 
-…
-        .drop_vars(["chain", "draw"])
-        .rename({"temp_chain": "chain", "temp_draw": "draw"})
-    )
-
-    if include_transformed:
-        idata.unconstrained_posterior = unstack_laplace_draws(
-            new_posterior.laplace_approximation.values, model, chains=chains, draws=draws
-        )
-
-    idata.posterior = new_posterior.drop_vars(
-        ["laplace_approximation", "unpacked_variable_names"]
-    )
-
+    # We override the posterior/unconstrained_posterior from find_MAP
+    idata.posterior, unconstrained_posterior = draws_from_laplace_approx(
+        mean=idata.fit["mean_vector"].values,
+        covariance=idata.fit["covariance_matrix"].values,
+        draws=draws,
+        return_unconstrained=include_transformed,
+        model=model,
+        vectorize_draws=vectorize_draws,
+        random_seed=random_seed,
+    )
+    if include_transformed:
+        idata.unconstrained_posterior = unconstrained_posterior
     return idata
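
Putting the signature changes together, a minimal end-to-end sketch (assumes `fit_laplace` remains importable from the package top level, as in previous releases):

import numpy as np
import pymc as pm

from pymc_extras import fit_laplace

observed = np.random.default_rng(0).normal(loc=1.0, size=50)

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu, 1, observed=observed)
    # `chains` is gone; `vectorize_draws` (default True) is new in 0.8.0
    idata = fit_laplace(draws=500, random_seed=0)

print(idata.posterior["mu"].shape)  # (1, 500)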