pymc-extras 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/distributions/timeseries.py +10 -10
- pymc_extras/inference/dadvi/dadvi.py +14 -83
- pymc_extras/inference/laplace_approx/laplace.py +186 -158
- pymc_extras/inference/pathfinder/pathfinder.py +11 -6
- pymc_extras/inference/smc/sampling.py +2 -2
- pymc_extras/model/marginal/distributions.py +4 -2
- pymc_extras/model/marginal/marginal_model.py +12 -2
- pymc_extras/statespace/core/statespace.py +2 -1
- pymc_extras/statespace/filters/distributions.py +15 -13
- pymc_extras/statespace/filters/kalman_filter.py +12 -11
- pymc_extras/statespace/filters/kalman_smoother.py +2 -2
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.7.0.dist-info}/METADATA +4 -4
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.7.0.dist-info}/RECORD +15 -15
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.7.0.dist-info}/WHEEL +0 -0
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.7.0.dist-info}/licenses/LICENSE +0 -0
pymc_extras/distributions/timeseries.py

@@ -196,21 +196,20 @@ class DiscreteMarkovChain(Distribution):
         state_rng = pytensor.shared(np.random.default_rng())

         def transition(*args):
-            *states, transition_probs
+            old_rng, *states, transition_probs = args
             p = transition_probs[tuple(states)]
             next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
-            return
+            return next_rng, next_state

-        markov_chain, state_updates = pytensor.scan(
+        state_next_rng, markov_chain = pytensor.scan(
             transition,
-
-
+            outputs_info=[state_rng, *_make_outputs_info(n_lags, init_dist_)],
+            non_sequences=[P_],
             n_steps=steps_,
             strict=True,
+            return_updates=False,
         )

-        (state_next_rng,) = tuple(state_updates.values())
-
         discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)

         discrete_mc_op = DiscreteMarkovChainRV(

@@ -239,16 +238,17 @@ def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng):
     n_lags = op.n_lags

     def greedy_transition(*args):
-        *states, transition_probs
+        *states, transition_probs = args
         p = transition_probs[tuple(states)]
         return pt.argmax(p)

-    chain_moment, _ = pytensor.scan(
+    chain_moment = pytensor.scan(
         greedy_transition,
-        non_sequences=[P
+        non_sequences=[P],
        outputs_info=_make_outputs_info(n_lags, init_dist),
         n_steps=steps,
         strict=True,
+        return_updates=False,
     )
     chain_moment = pt.concatenate([init_dist_moment, chain_moment])
     return chain_moment
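The change repeated throughout this release is an adaptation to pytensor's newer scan API: with return_updates=False (pytensor >= 2.36, per the METADATA pin at the bottom of this diff), pytensor.scan returns its outputs directly instead of an (outputs, updates) tuple, and random generators are threaded explicitly through outputs_info rather than recovered from the updates dict. A minimal sketch of the two idioms, assuming only that keyword; the counter loop is illustrative:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x0 = pt.scalar("x0")

    # Old idiom: scan always returned (outputs, updates); RNG state came back
    # through the updates dict.
    outputs, updates = pytensor.scan(lambda x: x + 1.0, outputs_info=[x0], n_steps=5)

    # New idiom: return_updates=False hands back the outputs directly, which is
    # what the hunks above (and the pathfinder, statespace, and Kalman filter
    # hunks below) migrate to.
    outputs = pytensor.scan(
        lambda x: x + 1.0, outputs_info=[x0], n_steps=5, return_updates=False
    )
    fn = pytensor.function([x0], outputs)
    print(fn(np.array(0.0)))  # [1. 2. 3. 4. 5.]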
pymc_extras/inference/dadvi/dadvi.py

@@ -3,25 +3,20 @@ import numpy as np
 import pymc
 import pytensor
 import pytensor.tensor as pt
-import xarray

+from arviz import InferenceData
 from better_optimize import basinhopping, minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
-from pymc.backends.arviz import (
-    PointFunc,
-    apply_function_over_dataset,
-    coords_and_dims_for_inferencedata,
-)
 from pymc.blocking import RaveledVars
-from pymc.util import RandomSeed, get_default_varnames
+from pymc.util import RandomSeed
 from pytensor.tensor.variable import TensorVariable

 from pymc_extras.inference.laplace_approx.idata import (
     add_data_to_inference_data,
     add_optimizer_result_to_inference_data,
 )
-from pymc_extras.inference.laplace_approx.laplace import
+from pymc_extras.inference.laplace_approx.laplace import draws_from_laplace_approx
 from pymc_extras.inference.laplace_approx.scipy_interface import (
     scipy_optimize_funcs_from_loss,
     set_optimizer_function_defaults,

@@ -193,16 +188,18 @@ def fit_dadvi(
     opt_var_params = result.x
     opt_means, opt_log_sds = np.split(opt_var_params, 2)

-
-
-
-
-
-
-
-
-        draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+    posterior, unconstrained_posterior = draws_from_laplace_approx(
+        mean=opt_means,
+        standard_deviation=np.exp(opt_log_sds),
+        draws=n_draws,
+        model=model,
+        vectorize_draws=False,
+        return_unconstrained=include_transformed,
+        random_seed=random_seed,
     )
+    idata = InferenceData(posterior=posterior)
+    if include_transformed:
+        idata.add_groups(unconstrained_posterior=unconstrained_posterior)

     var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
     var_name_to_model_var.update(

@@ -283,69 +280,3 @@ def create_dadvi_graph(
     objective = -mean_log_density - entropy

     return var_params, objective
-
-
-def dadvi_result_to_idata(
-    unstacked_draws: xarray.Dataset,
-    model: Model,
-    include_transformed: bool = False,
-    progressbar: bool = True,
-):
-    """
-    Transforms the unconstrained draws back into the constrained space.
-
-    Parameters
-    ----------
-    unstacked_draws : xarray.Dataset
-        The draws to constrain back into the original space.
-
-    model : Model
-        The PyMC model the variables were derived from.
-
-    n_draws: int
-        The number of draws to return from the variational approximation.
-
-    include_transformed: bool
-        Whether or not to keep the unconstrained variables in the output.
-
-    progressbar: bool
-        Whether or not to show a progress bar during the transformation. Default is True.
-
-    Returns
-    -------
-    :class:`~arviz.InferenceData`
-        Draws from the original constrained parameters.
-    """
-
-    filtered_var_names = model.unobserved_value_vars
-    vars_to_sample = list(
-        get_default_varnames(filtered_var_names, include_transformed=include_transformed)
-    )
-    fn = pytensor.function(model.value_vars, vars_to_sample)
-    point_func = PointFunc(fn)
-
-    coords, dims = coords_and_dims_for_inferencedata(model)
-
-    transformed_result = apply_function_over_dataset(
-        point_func,
-        unstacked_draws,
-        output_var_names=[x.name for x in vars_to_sample],
-        coords=coords,
-        dims=dims,
-        progressbar=progressbar,
-    )
-
-    constrained_names = [
-        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
-    ]
-    all_varnames = [
-        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
-    ]
-    unconstrained_names = sorted(set(all_varnames) - set(constrained_names))
-
-    idata = az.InferenceData(posterior=transformed_result[constrained_names])
-
-    if unconstrained_names and include_transformed:
-        idata["unconstrained_posterior"] = transformed_result[unconstrained_names]
-
-    return idata
pymc_extras/inference/laplace_approx/laplace.py

@@ -16,9 +16,7 @@
 import logging

 from collections.abc import Callable
-from functools import partial
 from typing import Literal
-from typing import cast as type_cast

 import arviz as az
 import numpy as np

@@ -27,16 +25,18 @@ import pytensor
 import pytensor.tensor as pt
 import xarray as xr

+from arviz import dict_to_dataset
 from better_optimize.constants import minimize_method
 from numpy.typing import ArrayLike
+from pymc import Model
+from pymc.backends.arviz import coords_and_dims_for_inferencedata
 from pymc.blocking import DictToArrayBijection
 from pymc.model.transform.optimization import freeze_dims_and_data
-from pymc.
-from pymc.util import get_default_varnames
+from pymc.util import get_untransformed_name, is_transformed_name
 from pytensor.graph import vectorize_graph
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.optimize import minimize
-from
+from xarray import Dataset

 from pymc_extras.inference.laplace_approx.find_map import (
     _compute_inverse_hessian,
@@ -147,138 +147,175 @@ def get_conditional_gaussian_approximation(
     return pytensor.function(args, [x0, conditional_gaussian_approx])


-def
-
-
-
-    ]
-    names = [x.name for x in outputs]
+def unpack_last_axis(packed_input, packed_shapes):
+    if len(packed_shapes) == 1:
+        # Single case currently fails in unpack
+        return [pt.split_dims(packed_input, packed_shapes[0], axis=-1)]

-
+    keep_axes = tuple(range(packed_input.ndim))[:-1]
+    return pt.unpack(packed_input, axes=keep_axes, packed_shapes=packed_shapes)

-    new_outputs, unconstrained_vector = join_nonshared_inputs(
-        model.initial_point(),
-        inputs=model.value_vars,
-        outputs=outputs,
-    )
-
-    constrained_rvs = [x for x, name in zip(new_outputs, names) if name in constrained_names]
-    value_rvs = [x for x in new_outputs if x not in constrained_rvs]
-
-    unconstrained_vector.name = "unconstrained_vector"

-
-
-
-
-
-
+def draws_from_laplace_approx(
+    *,
+    mean,
+    covariance=None,
+    standard_deviation=None,
+    draws: int,
+    model: Model,
+    vectorize_draws: bool = True,
+    return_unconstrained: bool = True,
+    random_seed=None,
+    compile_kwargs: dict | None = None,
+) -> tuple[Dataset, Dataset | None]:
+    """
+    Generate draws from the Laplace approximation of the posterior.

-
+    Parameters
+    ----------
+    mean : np.ndarray
+        The mean of the Laplace approximation (MAP estimate).
+    covariance : np.ndarray, optional
+        The covariance matrix of the Laplace approximation.
+        Mutually exclusive with `standard_deviation`.
+    standard_deviation : np.ndarray, optional
+        The standard deviation of the Laplace approximation (diagonal approximation).
+        Mutually exclusive with `covariance`.
+    draws : int
+        The number of draws.
+    model : pm.Model
+        The PyMC model.
+    vectorize_draws : bool, default True
+        Whether to vectorize the draws.
+    return_unconstrained : bool, default True
+        Whether to return the unconstrained draws in addition to the constrained ones.
+    random_seed : int, optional
+        Random seed for reproducibility.
+    compile_kwargs: dict, optional
+        Optional compile kwargs

+    Returns
+    -------
+    tuple[Dataset, Dataset | None]
+        A tuple containing the constrained draws (trace) and optionally the unconstrained draws.
+
+    Raises
+    ------
+    ValueError
+        If neither `covariance` nor `standard_deviation` is provided,
+        or if both are provided.
+    """
+    # This function assumes that mean/covariance/standard_deviation are aligned with model.initial_point()
+    if covariance is None and standard_deviation is None:
+        raise ValueError("Must specify either covariance or standard_deviation")
+    if covariance is not None and standard_deviation is not None:
+        raise ValueError("Cannot specify both covariance and standard_deviation")
+    if compile_kwargs is None:
+        compile_kwargs = {}

-def model_to_laplace_approx(
-    model: pm.Model, unpacked_variable_names: list[str], chains: int = 1, draws: int = 500
-):
     initial_point = model.initial_point()
-
-
-
-
-
-
-
-
-
-
+    n = int(np.sum([np.prod(v.shape) for v in initial_point.values()]))
+    assert mean.shape == (n,)
+    if covariance is not None:
+        assert covariance.shape == (n, n)
+    elif standard_deviation is not None:
+        assert standard_deviation.shape == (n,)
+
+    vars_to_sample = [v for v in model.free_RVs + model.deterministics]
+    var_names = [v.name for v in vars_to_sample]
+
+    orig_constrained_vars = model.value_vars
+    orig_outputs = model.replace_rvs_by_values(vars_to_sample)
+    if return_unconstrained:
+        orig_outputs.extend(model.value_vars)
+
+    mu_pt = pt.vector("mu", shape=(n,), dtype=mean.dtype)
+    size = (draws,) if vectorize_draws else ()
+    if covariance is not None:
+        sigma_pt = pt.matrix("cov", shape=(n, n), dtype=covariance.dtype)
+        laplace_approximation = pm.MvNormal.dist(mu=mu_pt, cov=sigma_pt, size=size, method="svd")
+    else:
+        sigma_pt = pt.vector("sigma", shape=(n,), dtype=standard_deviation.dtype)
+        laplace_approximation = pm.Normal.dist(mu=mu_pt, sigma=sigma_pt, size=(*size, n))
+
+    constrained_vars = unpack_last_axis(
+        laplace_approximation,
+        [initial_point[v.name].shape for v in orig_constrained_vars],
+    )
+    outputs = vectorize_graph(
+        orig_outputs, replace=dict(zip(orig_constrained_vars, constrained_vars))
     )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            dims = batch_dims
-        elif name in model.named_vars_to_dims:
-            dims = (*batch_dims, *model.named_vars_to_dims[name])
-        else:
-            dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
-        initval = initial_point.get(name, None)
-        dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
-        laplace_model.add_coords(
-            {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
-        )
-
-        pm.Deterministic(name, batched_rv, dims=dims)
-
-    return laplace_model
-
-
-def unstack_laplace_draws(laplace_data, model, chains=2, draws=500):
-    """
-    The `model_to_laplace_approx` function returns a model with a single MvNormal distribution, draws from which are
-    in the unconstrained variable space. These might be interesting to the user, but since they come back stacked in a
-    single vector, it's not easy to work with.
-
-    This function unpacks each component of the vector into its own DataArray, with the appropriate dimensions and
-    coordinates, where possible.
-    """
-    initial_point = DictToArrayBijection.map(model.initial_point())
-
-    cursor = 0
-    unstacked_laplace_draws = {}
-    coords = model.coords | {"chain": range(chains), "draw": range(draws)}
-
-    # There are corner cases where the value_vars will not have the same dimensions as the random variable (e.g.
-    # simplex transform of a Dirichlet). In these cases, we don't try to guess what the labels should be, and just
-    # add an arviz-style default dim and label.
-    for rv, (name, shape, size, dtype) in zip(model.free_RVs, initial_point.point_map_info):
-        rv_dims = []
-        for i, dim in enumerate(
-            model.named_vars_to_dims.get(rv.name, [f"{name}_dim_{i}" for i in range(len(shape))])
-        ):
-            if coords.get(dim) and shape[i] == len(coords[dim]):
-                rv_dims.append(dim)
-            else:
-                rv_dims.append(f"{name}_dim_{i}")
-                coords[f"{name}_dim_{i}"] = np.arange(shape[i])
-
-        dims = ("chain", "draw", *rv_dims)
-
-        values = (
-            laplace_data[..., cursor : cursor + size].reshape((chains, draws, *shape)).astype(dtype)
+    fn = pm.pytensorf.compile(
+        [mu_pt, sigma_pt],
+        outputs,
+        random_seed=random_seed,
+        trust_input=True,
+        **compile_kwargs,
+    )
+    sigma = covariance if covariance is not None else standard_deviation
+    if vectorize_draws:
+        output_buffers = fn(mean, sigma)
+    else:
+        # Take one draw to find the shape of the outputs
+        output_buffers = []
+        for out_draw in fn(mean, sigma):
+            output_buffer = np.empty((draws, *out_draw.shape), dtype=out_draw.dtype)
+            output_buffer[0] = out_draw
+            output_buffers.append(output_buffer)
+        # Fill one draws at a time
+        for i in range(1, draws):
+            for out_buffer, out_draw in zip(output_buffers, fn(mean, sigma)):
+                out_buffer[i] = out_draw
+
+    model_coords, model_dims = coords_and_dims_for_inferencedata(model)
+    posterior = {
+        var_name: out_buffer[None]
+        for var_name, out_buffer in (
+            zip(var_names, output_buffers, strict=not return_unconstrained)
         )

-
-
+    }
+    posterior_dataset = dict_to_dataset(posterior, coords=model_coords, dims=model_dims, library=pm)
+    unconstrained_posterior_dataset = None
+
+    if return_unconstrained:
+        unconstrained_posterior = {
+            var.name: out_buffer[None]
+            for var, out_buffer in zip(
+                model.value_vars, output_buffers[len(posterior) :], strict=True
+            )
+        }
+        # Attempt to map constrained dims to unconstrained dims
+        for var_name, var_draws in unconstrained_posterior.items():
+            if not is_transformed_name(var_name):
+                # constrained == unconstrained, dims already shared
+                continue
+            constrained_dims = model_dims.get(get_untransformed_name(var_name))
+            if constrained_dims is None or (len(constrained_dims) != (var_draws.ndim - 2)):
+                continue
+            # Reuse dims from constrained variable if they match in length with unconstrained draws
+            inferred_dims = []
+            for i, (constrained_dim, unconstrained_dim_length) in enumerate(
+                zip(constrained_dims, var_draws.shape[2:], strict=True)
+            ):
+                if model_coords.get(constrained_dim) is not None and (
+                    len(model_coords[constrained_dim]) == unconstrained_dim_length
+                ):
+                    # Assume coordinates map. This could be fooled, by e.g., having a transform that reverses values
+                    inferred_dims.append(constrained_dim)
+                else:
+                    # Size mismatch (e.g., Simplex), make no assumption about mapping
+                    inferred_dims.append(f"{var_name}_dim_{i}")
+            model_dims[var_name] = inferred_dims
+
+        unconstrained_posterior_dataset = dict_to_dataset(
+            unconstrained_posterior,
+            coords=model_coords,
+            dims=model_dims,
+            library=pm,
         )

-
-
-    unstacked_laplace_draws = xr.Dataset(unstacked_laplace_draws)
-
-    return unstacked_laplace_draws
+    return posterior_dataset, unconstrained_posterior_dataset


 def fit_laplace(
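For orientation, a hypothetical call of the new helper. The mean and covariance here are hand-built values aligned with model.initial_point() (in fit_laplace and fit_dadvi they come from the optimizer), and the model is a toy:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.laplace_approx.laplace import draws_from_laplace_approx

    with pm.Model() as model:
        x = pm.Normal("x", 0.0, 1.0)
        y = pm.HalfNormal("y", 1.0)

    # Two free scalars, so the packed vector has length 2 (y enters on the log scale).
    posterior, unconstrained = draws_from_laplace_approx(
        mean=np.zeros(2),
        covariance=np.eye(2),
        draws=200,
        model=model,
        return_unconstrained=True,
    )
    print(posterior["x"].shape)  # (1, 200): a singleton chain dimension is prepended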
@@ -295,8 +332,9 @@ def fit_laplace(
     include_transformed: bool = True,
     freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
-    chains: int = 2,
+    chains: None | int = None,
     draws: int = 500,
+    vectorize_draws: bool = True,
     optimizer_kwargs: dict | None = None,
     compile_kwargs: dict | None = None,
 ) -> az.InferenceData:

@@ -343,16 +381,14 @@ def fit_laplace(
         True.
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
-    chains: int, default: 2
-        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
-        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
-        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior.
+        The number of samples to draw from the approximated posterior.
     optimizer_kwargs
         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
         ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
+    vectorize_draws: bool, default True
+        Whether to natively vectorize the random function or take one at a time in a python loop.
     compile_kwargs: dict, optional
         Additional keyword arguments to pass to pytensor.function.

@@ -385,6 +421,12 @@ def fit_laplace(
     will forward the call to 'fit_laplace'.

     """
+    if chains is not None:
+        raise ValueError(
+            "chains argument has been deprecated. "
+            "The behavior can be recreated by unstacking draws into multiple chains after fitting"
+        )
+
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
     model = pm.modelcontext(model) if model is None else model

@@ -410,11 +452,10 @@ def fit_laplace(
         **optimizer_kwargs,
     )

-    unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
-
     if "covariance_matrix" not in idata.fit:
         # The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so
         # we have to go back and compute the Hessian at the MAP point now.
+        unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
         frozen_model = freeze_dims_and_data(model)
         initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)

@@ -443,29 +484,16 @@ def fit_laplace(
         coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
     )

-
-
-
-
-
-
-
-
-
-
-
-
-            .drop_vars(["chain", "draw"])
-            .rename({"temp_chain": "chain", "temp_draw": "draw"})
-        )
-
-    if include_transformed:
-        idata.unconstrained_posterior = unstack_laplace_draws(
-            new_posterior.laplace_approximation.values, model, chains=chains, draws=draws
-        )
-
-    idata.posterior = new_posterior.drop_vars(
-        ["laplace_approximation", "unpacked_variable_names"]
-    )
-
+    # We override the posterior/unconstrained_posterior from find_MAP
+    idata.posterior, unconstrained_posterior = draws_from_laplace_approx(
+        mean=idata.fit["mean_vector"].values,
+        covariance=idata.fit["covariance_matrix"].values,
+        draws=draws,
+        return_unconstrained=include_transformed,
+        model=model,
+        vectorize_draws=vectorize_draws,
+        random_seed=random_seed,
+    )
+    if include_transformed:
+        idata.unconstrained_posterior = unconstrained_posterior
     return idata
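A sketch of the workaround named in the new error message, not code from the release: recreate the old chains=2 layout by reshaping the single returned chain after fitting. It assumes idata comes from fit_laplace with draws=1000 and a recent xarray:

    import pandas as pd
    import xarray as xr

    chains, draws_per_chain = 2, 500
    idx = pd.MultiIndex.from_product(
        [range(chains), range(draws_per_chain)], names=["chain", "draw"]
    )
    idata.posterior = (
        idata.posterior.isel(chain=0, drop=True)  # drop the singleton chain dim
        .rename(draw="sample")
        .assign_coords(xr.Coordinates.from_pandas_multiindex(idx, "sample"))
        .unstack("sample")
    )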
pymc_extras/inference/pathfinder/pathfinder.py

@@ -278,12 +278,13 @@ def alpha_recover(
     z = pt.diff(g, axis=0)
     alpha_l_init = pt.ones(N)

-    alpha, _ = pytensor.scan(
+    alpha = pytensor.scan(
         fn=compute_alpha_l,
         outputs_info=alpha_l_init,
         sequences=[s, z],
         n_steps=Lp1 - 1,
         allow_gc=False,
+        return_updates=False,
     )

     # assert np.all(alpha.eval() > 0), "alpha cannot be negative"

@@ -334,11 +335,12 @@ def inverse_hessian_factors(
         return pt.set_subtensor(chi_l[j_last], diff_l)

     chi_init = pt.zeros((J, N))
-    chi_mat, _ = pytensor.scan(
+    chi_mat = pytensor.scan(
         fn=chi_update,
         outputs_info=chi_init,
         sequences=[diff],
         allow_gc=False,
+        return_updates=False,
     )

     chi_mat = pt.matrix_transpose(chi_mat)

@@ -377,14 +379,14 @@ def inverse_hessian_factors(
     eta = pt.diagonal(E, axis1=-2, axis2=-1)

     # beta: (L, N, 2J)
-    alpha_diag, _ = pytensor.scan(lambda a: pt.diag(a), sequences=[alpha])
+    alpha_diag = pytensor.scan(lambda a: pt.diag(a), sequences=[alpha], return_updates=False)
     beta = pt.concatenate([alpha_diag @ Z, S], axis=-1)

     # more performant and numerically precise to use solve than inverse: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.inv.html

     # E_inv: (L, J, J)
     E_inv = pt.slinalg.solve_triangular(E, Ij, check_finite=False)
-    eta_diag, _ = pytensor.scan(pt.diag, sequences=[eta])
+    eta_diag = pytensor.scan(pt.diag, sequences=[eta], return_updates=False)

     # block_dd: (L, J, J)
     block_dd = (

@@ -530,7 +532,9 @@ def bfgs_sample_sparse(

     # qr_input: (L, N, 2J)
     qr_input = inv_sqrt_alpha_diag @ beta
-
+    Q, R = pytensor.scan(
+        fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False, return_updates=False
+    )

     IdN = pt.eye(R.shape[1])[None, ...]
     IdN += IdN * REGULARISATION_TERM

@@ -623,10 +627,11 @@ def bfgs_sample(

     L, N, JJ = beta.shape

-
+    alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag = pytensor.scan(
         lambda a: [pt.diag(a), pt.diag(pt.sqrt(1.0 / a)), pt.diag(pt.sqrt(a))],
         sequences=[alpha],
         allow_gc=False,
+        return_updates=False,
     )

     u = pt.random.normal(size=(L, num_samples, N))
pymc_extras/inference/smc/sampling.py

@@ -238,7 +238,7 @@ class SMCDiagnostics(NamedTuple):
 def update_diagnosis(i, history, info, state):
     le, lli, ancestors, weights_evolution = history
     return SMCDiagnostics(
-        le.at[i].set(state.lmbda),
+        le.at[i].set(state.tempering_param),
         lli.at[i].set(info.log_likelihood_increment),
         ancestors.at[i].set(info.ancestors),
         weights_evolution.at[i].set(state.weights),

@@ -265,7 +265,7 @@ def inference_loop(rng_key, initial_state, kernel, iterations_to_diagnose, n_particles):

     def cond(carry):
         i, state, _, _ = carry
-        return state.lmbda < 1
+        return state.tempering_param < 1

     def one_step(carry):
         i, state, k, previous_info = carry
pymc_extras/model/marginal/distributions.py

@@ -282,11 +282,12 @@ def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs):
     def logp_fn(marginalized_rv_const, *non_sequences):
         return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

-    joint_logps, _ = scan_map(
+    joint_logps = scan_map(
         fn=logp_fn,
         sequences=marginalized_rv_domain_tensor,
         non_sequences=[*values, *inputs],
         mode=Mode().including("local_remove_check_parameter"),
+        return_updates=False,
     )

     joint_logp = pt.logsumexp(joint_logps, axis=0)

@@ -350,12 +351,13 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):

     P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
     log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
-    log_alpha_seq, _ = scan(
+    log_alpha_seq = scan(
         step_alpha,
         non_sequences=[log_P],
         outputs_info=[log_alpha_init],
         # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
         sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
+        return_updates=False,
     )
     # Final logp is just the sum of the last scan state
     joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
pymc_extras/model/marginal/marginal_model.py

@@ -11,7 +11,7 @@ from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.transforms import IntervalTransform
-from pymc.model import Model
+from pymc.model import Model, modelcontext
 from pymc.model.fgraph import (
     ModelFreeRV,
     ModelValuedVar,

@@ -337,8 +337,9 @@ def transform_posterior_pts(model, posterior_pts):


 def recover_marginals(
-    model: Model,
     idata: InferenceData,
+    *,
+    model: Model | None = None,
     var_names: Sequence[str] | None = None,
     return_samples: bool = True,
     extend_inferencedata: bool = True,

@@ -389,6 +390,15 @@ def recover_marginals(


     """
+    # Temporary error message for helping with migration
+    # Will be removed in a future release
+    if isinstance(idata, Model):
+        raise TypeError(
+            "The order of arguments of `recover_marginals` changed. The first input must be an idata"
+        )
+
+    model = modelcontext(model)
+
     unmarginal_model = unmarginalize(model)

     # Find the names of the marginalized variables
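In practice the migration looks like this (a sketch: the variable name "idx" is illustrative, and model and idata are assumed to come from an earlier marginalize(...) plus pm.sample(...) call):

    from pymc_extras.model.marginal.marginal_model import recover_marginals

    # Old (0.6.0) call order: recover_marginals(model, idata, var_names=["idx"])
    idata = recover_marginals(idata, model=model, var_names=["idx"])

    # Inside the model context the model argument can be omitted entirely,
    # since it is now resolved via modelcontext:
    with model:
        idata = recover_marginals(idata, var_names=["idx"])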
pymc_extras/statespace/core/statespace.py

@@ -2500,13 +2500,14 @@ class PyMCStateSpace:
             next_x = c + T @ x + R @ shock
             return next_x

-        irf, _ = pytensor.scan(
+        irf = pytensor.scan(
             irf_step,
             sequences=[shock_trajectory],
             outputs_info=[x0],
             non_sequences=[c, T, R],
             n_steps=n_steps,
             strict=True,
+            return_updates=False,
         )

         pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM])
pymc_extras/statespace/filters/distributions.py

@@ -197,10 +197,9 @@ class _LinearGaussianStateSpace(Continuous):
         n_seq = len(sequence_names)

         def step_fn(*args):
-            seqs, state, non_seqs = args[:n_seq], args[n_seq
-            non_seqs, rng = non_seqs[:-1], non_seqs[-1]
+            seqs, (rng, state, *non_seqs) = args[:n_seq], args[n_seq:]

-            c, d, T, Z, R, H, Q = sort_args(seqs
+            c, d, T, Z, R, H, Q = sort_args((*seqs, *non_seqs))
             k = T.shape[0]
             a = state[:k]

@@ -219,7 +218,7 @@ class _LinearGaussianStateSpace(Continuous):

             next_state = pt.concatenate([a_next, y_next], axis=0)

-            return
+            return next_rng, next_state

         Z_init = Z_ if Z_ in non_sequences else Z_[0]
         H_init = H_ if H_ in non_sequences else H_[0]

@@ -229,13 +228,14 @@ class _LinearGaussianStateSpace(Continuous):

         init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)

-        statespace, updates = pytensor.scan(
+        ss_rng, statespace = pytensor.scan(
             step_fn,
-            outputs_info=[init_dist_],
+            outputs_info=[rng, init_dist_],
             sequences=None if len(sequences) == 0 else sequences,
-            non_sequences=[*non_sequences, rng],
+            non_sequences=[*non_sequences],
             n_steps=steps,
             strict=True,
+            return_updates=False,
         )

@@ -245,7 +245,6 @@ class _LinearGaussianStateSpace(Continuous):
         statespace_ = statespace
         statespace_ = pt.specify_shape(statespace_, (steps, None))

-        (ss_rng,) = tuple(updates.values())
         linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
             inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps, rng],
             outputs=[ss_rng, statespace_],

@@ -385,10 +384,15 @@ class SequenceMvNormal(Continuous):

         def step(mu, cov, rng):
             new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method=method).owner.outputs
-            return
+            return new_rng, mvn

-        mvn_seq, updates = pytensor.scan(
-            step,
+        seq_mvn_rng, mvn_seq = pytensor.scan(
+            step,
+            sequences=[mus_, covs_],
+            outputs_info=[rng, None],
+            strict=True,
+            n_steps=mus_.shape[0],
+            return_updates=False,
         )
         mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)

@@ -396,8 +400,6 @@ class SequenceMvNormal(Continuous):
         if mvn_seq.ndim > 2:
             mvn_seq = pt.moveaxis(mvn_seq, 0, -2)

-        (seq_mvn_rng,) = tuple(updates.values())
-
         mvn_seq_op = KalmanFilterRV(
             inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
         )
pymc_extras/statespace/filters/kalman_filter.py

@@ -148,10 +148,9 @@ class BaseFilter(ABC):
         R,
         H,
         Q,
-        return_updates=False,
         missing_fill_value=None,
         cov_jitter=None,
-    ) -> list[TensorVariable]
+    ) -> list[TensorVariable]:
         """
         Construct the computation graph for the Kalman filter. See [1] for details.

@@ -211,20 +210,17 @@ class BaseFilter(ABC):
         if len(sequences) > 0:
             sequences = self.add_check_on_time_varying_shapes(data, sequences)

-        results, updates = pytensor.scan(
+        results = pytensor.scan(
             self.kalman_step,
             sequences=[data, *sequences],
             outputs_info=[None, a0, None, None, P0, None, None],
             non_sequences=non_sequences,
             name="forward_kalman_pass",
             strict=False,
+            return_updates=False,
         )

-
-
-        if return_updates:
-            return filter_results, updates
-        return filter_results
+        return self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0])

     def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
         """

@@ -652,7 +648,9 @@ class SquareRootFilter(BaseFilter):
         y_hat = Z.dot(a) + d
         v = y - y_hat

-        H_chol = pytensor.ifelse(
+        H_chol = pytensor.ifelse(
+            pt.all(pt.eq(H, 0.0)), H, pt.linalg.cholesky(H, lower=True, on_error="nan")
+        )

         # The following notation comes from https://ipnpr.jpl.nasa.gov/progress_report/42-233/42-233A.pdf
         # Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],

@@ -694,8 +692,10 @@ class SquareRootFilter(BaseFilter):
             """
             return [a, P_chol, pt.zeros(())]

+        degenerate = pt.eq(all_nan_flag, 1.0)
+        F_chol = pytensor.ifelse(degenerate, pt.eye(*F_chol.shape), F_chol)
         [a_filtered, P_chol_filtered, ll] = pytensor.ifelse(
-
+            degenerate,
             compute_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
             compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v),
         )

@@ -786,11 +786,12 @@ class UnivariateFilter(BaseFilter):
         H_masked = W.dot(H)
         y_masked = pt.set_subtensor(y[nan_mask], 0.0)

-        result, _ = pytensor.scan(
+        result = pytensor.scan(
             self._univariate_inner_filter_step,
             sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask],
             outputs_info=[a, P, None, None, None],
             name="univariate_inner_scan",
+            return_updates=False,
         )

         a_filtered, P_filtered, obs_mu, obs_cov, ll_inner = result
pymc_extras/statespace/filters/kalman_smoother.py

@@ -76,16 +76,16 @@ class KalmanSmoother:
         self.seq_names = seq_names
         self.non_seq_names = non_seq_names

-        smoother_result, updates = pytensor.scan(
+        smoothed_states, smoothed_covariances = pytensor.scan(
             self.smoother_step,
             sequences=[filtered_states[:-1], filtered_covariances[:-1], *sequences],
             outputs_info=[a_last, P_last],
             non_sequences=non_sequences,
             go_backwards=True,
             name="kalman_smoother",
+            return_updates=False,
         )

-        smoothed_states, smoothed_covariances = smoother_result
         smoothed_states = pt.concatenate(
             [smoothed_states[::-1], pt.expand_dims(a_last, axis=(0,))], axis=0
         )
{pymc_extras-0.6.0.dist-info → pymc_extras-0.7.0.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pymc-extras
-Version: 0.6.0
+Version: 0.7.0
 Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
 Project-URL: Documentation, https://pymc-extras.readthedocs.io/
 Project-URL: Repository, https://github.com/pymc-devs/pymc-extras.git

@@ -235,8 +235,8 @@ Requires-Python: >=3.11
 Requires-Dist: better-optimize>=0.1.5
 Requires-Dist: preliz>=0.20.0
 Requires-Dist: pydantic>=2.0.0
-Requires-Dist: pymc>=5.
-Requires-Dist: pytensor>=2.
+Requires-Dist: pymc>=5.27.0
+Requires-Dist: pytensor>=2.36.3
 Requires-Dist: scikit-learn
 Provides-Extra: complete
 Requires-Dist: dask[complete]<2025.1.1; extra == 'complete'

@@ -245,7 +245,7 @@ Provides-Extra: dask-histogram
 Requires-Dist: dask[complete]<2025.1.1; extra == 'dask-histogram'
 Requires-Dist: xhistogram; extra == 'dask-histogram'
 Provides-Extra: dev
-Requires-Dist: blackjax; extra == 'dev'
+Requires-Dist: blackjax>=0.12; extra == 'dev'
 Requires-Dist: dask[all]<2025.1.1; extra == 'dev'
 Requires-Dist: pytest-mock; extra == 'dev'
 Requires-Dist: pytest>=6.0; extra == 'dev'
{pymc_extras-0.6.0.dist-info → pymc_extras-0.7.0.dist-info}/RECORD

@@ -8,7 +8,7 @@ pymc_extras/distributions/__init__.py,sha256=Cge3AP7gzD6qTJY7v2tYRtSgn-rlnIo7wQB
 pymc_extras/distributions/continuous.py,sha256=bCXOgnw2Vh_FbYOHCqB0c3ozFVay5Qwua2A211kvWNQ,11251
 pymc_extras/distributions/discrete.py,sha256=HNi-K0_hnNWTcfyBkWGh26sc71FwBgukQ_EjGAaAOjY,13036
 pymc_extras/distributions/histogram_utils.py,sha256=kkZHu1F_2qMfOEzwNP4K6QYA_xEKUk9cChImOQ2Nkjs,5847
-pymc_extras/distributions/timeseries.py,sha256=
+pymc_extras/distributions/timeseries.py,sha256=WysWtUchfObTGmKduF47bUBqV_g1kW-uAx4_oKENgDg,12709
 pymc_extras/distributions/multivariate/__init__.py,sha256=E8OeLW9tTotCbrUjEo4um76-_WQD56PehsPzkKmhfyA,93
 pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=5SzvD41pu-EWyWlDNz4AR4Sl8MkyC-1dYwkADFh5Avg,16009
 pymc_extras/distributions/transforms/__init__.py,sha256=FUp2vyRE6_2eUcQ_FVt5Dn0-vy5I-puV-Kz13-QtLNc,104

@@ -18,25 +18,25 @@ pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQD
 pymc_extras/inference/__init__.py,sha256=hI3yqfEVzoUNlCpL1z579F9EqM-NlPTzMfHj8IKY-xE,1009
 pymc_extras/inference/fit.py,sha256=hNTqLms_mTdjfnCEVIHMcMiPZ3fkU3HEEkbt6LWWhLw,1443
 pymc_extras/inference/dadvi/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pymc_extras/inference/dadvi/dadvi.py,sha256=
+pymc_extras/inference/dadvi/dadvi.py,sha256=rERiyMn1ywEerWJ8rq3WNZKtKEpX2lHAdqApatZyJpQ,9698
 pymc_extras/inference/laplace_approx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pymc_extras/inference/laplace_approx/find_map.py,sha256=fbK0swDsSBo7pP1TBokREa2wkK1ajL_gLVVuREHH33k,13658
 pymc_extras/inference/laplace_approx/idata.py,sha256=Dxj6A8aJXn8c24vD_PZmMgIgrwEmaYDlbw5UAJq0Nyw,14172
-pymc_extras/inference/laplace_approx/laplace.py,sha256=
+pymc_extras/inference/laplace_approx/laplace.py,sha256=J4Ddt7Jc1nRZvxHYUz2CWSpvJCJOMG3p2ayyUf1T7tE,20377
 pymc_extras/inference/laplace_approx/scipy_interface.py,sha256=Crhix_dLA8Y_NvuUDmVQnKWAWGjufmQwDLh-bK9dz_o,10235
 pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
 pymc_extras/inference/pathfinder/idata.py,sha256=muAPc9JeI8ZmpjzSp9tSj-uNrcsoNkYb4raJqjgf5UQ,18636
 pymc_extras/inference/pathfinder/importance_sampling.py,sha256=NwxepXOFit3cA5zEebniKdlnJ1rZWg56aMlH4MEOcG4,6264
 pymc_extras/inference/pathfinder/lbfgs.py,sha256=GOoJBil5Kft_iFwGNUGKSeqzI5x_shA4KQWDwgGuQtQ,7110
-pymc_extras/inference/pathfinder/pathfinder.py,sha256=
+pymc_extras/inference/pathfinder/pathfinder.py,sha256=IdKyJvGAeRstvTprKVQ4xk1hy6KjB8h-ggbmM7kMPEw,67345
 pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
-pymc_extras/inference/smc/sampling.py,sha256=
+pymc_extras/inference/smc/sampling.py,sha256=eyRIFPf--tcPpuHPNCxGZNQZVd7MazR4l9aURNY87S0,15385
 pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pymc_extras/model/model_api.py,sha256=UHMfQXxWBujeSiUySU0fDUC5Sd_BjT8FoVz3iBxQH_4,2400
 pymc_extras/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pymc_extras/model/marginal/distributions.py,sha256=
+pymc_extras/model/marginal/distributions.py,sha256=mf6Czm6av2nCydu6uKqjamKFPD8efWJmNlTMy4Ojrvk,15621
 pymc_extras/model/marginal/graph_analysis.py,sha256=Ft7RZC126R0TW2GuFdgb9uN-JSgDGTeffs-UuPcDHQE,15884
-pymc_extras/model/marginal/marginal_model.py,sha256=
+pymc_extras/model/marginal/marginal_model.py,sha256=Wgfcq6hplACU4Kh8aKT2P_kz_yo7r0wIVTupZQtQUKw,23969
 pymc_extras/model/transforms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pymc_extras/model/transforms/autoreparam.py,sha256=_NltGWmNqi_X9sHCqAvWcBveLTPxVy11-wENFTcN6kk,12377
 pymc_extras/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0

@@ -45,11 +45,11 @@ pymc_extras/statespace/__init__.py,sha256=PxV8i4aa2XJarRM6aKU14_bEY1AoLu4bNXIBy_
 pymc_extras/statespace/core/__init__.py,sha256=LEhkqdMZzzcTyzYml45IM4ykWoCdbWWj2c29IpM_ey8,309
 pymc_extras/statespace/core/compile.py,sha256=GB2H7sE28OdQ6GmNIjtq1R1Oua2GPf6kWJ7IPuYJaNA,1607
 pymc_extras/statespace/core/representation.py,sha256=boY-jjlkd3KuuO2XiSuV-GwEAyEqRJ9267H72AmE3BU,18956
-pymc_extras/statespace/core/statespace.py,sha256=
+pymc_extras/statespace/core/statespace.py,sha256=pyXIS95MJtJFzQozzdx4tNvf5-jY0Po3Z8aqaNAD7uo,108190
 pymc_extras/statespace/filters/__init__.py,sha256=F0EtZUhArp23lj3upy6zB0mDTjLIjwGh0pKmMny0QfY,420
-pymc_extras/statespace/filters/distributions.py,sha256
-pymc_extras/statespace/filters/kalman_filter.py,sha256=
-pymc_extras/statespace/filters/kalman_smoother.py,sha256=
+pymc_extras/statespace/filters/distributions.py,sha256=uLCs3iJObHyslOPLUFJp5G9w56AWJIofiFq2KozecXc,11881
+pymc_extras/statespace/filters/kalman_filter.py,sha256=x6J54t9cHi3tXKtCB6QW62Kyit5zjd5AnI3IFdxKtzw,31561
+pymc_extras/statespace/filters/kalman_smoother.py,sha256=JmnvXwHVzWTdmtgECTJY0FJOFZG6O9aEfRoSTWEeU2s,4111
 pymc_extras/statespace/filters/utilities.py,sha256=BBMDeWBcJWZfGc9owuMsOedVIXVDQ8Z2eMiU9vWeVr0,1494
 pymc_extras/statespace/models/DFM.py,sha256=EiZ3x4iFPGeha8bPp1tok4us8Z6UVUu1sFmKIM1i0xc,36458
 pymc_extras/statespace/models/ETS.py,sha256=LEsSKzbfm9Ol8UZQjNurcrM1CLQyozKfJtby7AzsDeI,27667

@@ -76,7 +76,7 @@ pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYk
 pymc_extras/utils/model_equivalence.py,sha256=9MLwSj7VwxxKupzmEkKBbwGD1X0WM2FGcGIpfb8bViw,2197
 pymc_extras/utils/prior.py,sha256=mnuFpamp04eQJuTU5NyB2PfCG5r-1McSmQGwQXSR_Lg,6670
 pymc_extras/utils/spline.py,sha256=R0u3eAcV5bRmD2YSLqDm0qnaJbEuf3V38OZ7amV7-Tc,4732
-pymc_extras-0.
-pymc_extras-0.
-pymc_extras-0.
-pymc_extras-0.
+pymc_extras-0.7.0.dist-info/METADATA,sha256=iTIf9JVSRSbvmGz-ASblyd_lR8Jj4eWrUsjgzx97QUw,18904
+pymc_extras-0.7.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+pymc_extras-0.7.0.dist-info/licenses/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
+pymc_extras-0.7.0.dist-info/RECORD,,
{pymc_extras-0.6.0.dist-info → pymc_extras-0.7.0.dist-info}/WHEEL
File without changes

{pymc_extras-0.6.0.dist-info → pymc_extras-0.7.0.dist-info}/licenses/LICENSE
File without changes