pymc-extras 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (38)
  1. pymc_extras/deserialize.py +10 -4
  2. pymc_extras/distributions/continuous.py +1 -1
  3. pymc_extras/distributions/histogram_utils.py +6 -4
  4. pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
  5. pymc_extras/distributions/timeseries.py +14 -12
  6. pymc_extras/inference/dadvi/dadvi.py +149 -128
  7. pymc_extras/inference/laplace_approx/find_map.py +16 -39
  8. pymc_extras/inference/laplace_approx/idata.py +22 -4
  9. pymc_extras/inference/laplace_approx/laplace.py +196 -151
  10. pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
  11. pymc_extras/inference/pathfinder/idata.py +517 -0
  12. pymc_extras/inference/pathfinder/pathfinder.py +71 -12
  13. pymc_extras/inference/smc/sampling.py +2 -2
  14. pymc_extras/model/marginal/distributions.py +4 -2
  15. pymc_extras/model/marginal/graph_analysis.py +2 -2
  16. pymc_extras/model/marginal/marginal_model.py +12 -2
  17. pymc_extras/model_builder.py +9 -4
  18. pymc_extras/prior.py +203 -8
  19. pymc_extras/statespace/core/compile.py +1 -1
  20. pymc_extras/statespace/core/statespace.py +2 -1
  21. pymc_extras/statespace/filters/distributions.py +15 -13
  22. pymc_extras/statespace/filters/kalman_filter.py +24 -22
  23. pymc_extras/statespace/filters/kalman_smoother.py +3 -5
  24. pymc_extras/statespace/filters/utilities.py +2 -5
  25. pymc_extras/statespace/models/DFM.py +12 -27
  26. pymc_extras/statespace/models/ETS.py +190 -198
  27. pymc_extras/statespace/models/SARIMAX.py +5 -17
  28. pymc_extras/statespace/models/VARMAX.py +15 -67
  29. pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
  30. pymc_extras/statespace/models/structural/components/regression.py +4 -26
  31. pymc_extras/statespace/models/utilities.py +7 -0
  32. pymc_extras/utils/model_equivalence.py +2 -2
  33. pymc_extras/utils/prior.py +10 -14
  34. pymc_extras/utils/spline.py +4 -10
  35. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/METADATA +4 -4
  36. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/RECORD +38 -37
  37. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/WHEEL +1 -1
  38. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -22,10 +22,15 @@ def make_default_labels(name: str, shape: tuple[int, ...]) -> list:
     return [list(range(dim)) for dim in shape]


-def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]:
+def make_unpacked_variable_names(
+    names: list[str], model: pm.Model, var_name_to_model_var: dict[str, str] | None = None
+) -> list[str]:
     coords = model.coords
     initial_point = model.initial_point()

+    if var_name_to_model_var is None:
+        var_name_to_model_var = {}
+
     value_to_dim = {
         value.name: model.named_vars_to_dims.get(model.values_to_rvs[value].name, None)
         for value in model.value_vars
@@ -37,6 +42,7 @@ def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]

     unpacked_variable_names = []
     for name in names:
+        name = var_name_to_model_var.get(name, name)
         shape = initial_point[name].shape
         if shape:
             dims = dims_dict.get(name)
@@ -109,7 +115,7 @@ def map_results_to_inference_data(
         x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
     ]

-    unconstrained_names = set(all_varnames) - set(constrained_names)
+    unconstrained_names = sorted(set(all_varnames) - set(constrained_names))

     idata = az.from_dict(
         posterior={
@@ -258,6 +264,7 @@ def optimizer_result_to_dataset(
     method: minimize_method | Literal["basinhopping"],
     mu: RaveledVars | None = None,
     model: pm.Model | None = None,
+    var_name_to_model_var: dict[str, str] | None = None,
 ) -> xr.Dataset:
     """
     Convert an OptimizeResult object to an xarray Dataset object.
@@ -268,6 +275,9 @@ def optimizer_result_to_dataset(
         The result of the optimization process.
     method: minimize_method or "basinhopping"
         The optimization method used.
+    var_name_to_model_var: dict, optional
+        Mapping between variables in the optimization result and the model variable names. Used when auxiliary
+        variables were introduced, e.g. in DADVI.

     Returns
     -------
@@ -279,7 +289,9 @@ def optimizer_result_to_dataset(

     model = pm.modelcontext(model) if model is None else model
     variable_names, *_ = zip(*mu.point_map_info)
-    unpacked_variable_names = make_unpacked_variable_names(variable_names, model)
+    unpacked_variable_names = make_unpacked_variable_names(
+        variable_names, model, var_name_to_model_var
+    )

     data_vars = {}

@@ -368,6 +380,7 @@ def add_optimizer_result_to_inference_data(
     method: minimize_method | Literal["basinhopping"],
     mu: RaveledVars | None = None,
     model: pm.Model | None = None,
+    var_name_to_model_var: dict[str, str] | None = None,
 ) -> az.InferenceData:
     """
     Add the optimization result to an InferenceData object.
@@ -384,13 +397,18 @@ def add_optimizer_result_to_inference_data(
         The MAP estimate of the model parameters.
     model: Model, optional
         A PyMC model. If None, the model is taken from the current model context.
+    var_name_to_model_var: dict, optional
+        Mapping between variables in the optimization result and the model variable names. Used when auxiliary
+        variables were introduced, e.g. in DADVI.

     Returns
     -------
     idata: az.InferenceData
         The provided InferenceData, with the optimization results added to the "optimizer" group.
     """
-    dataset = optimizer_result_to_dataset(result, method=method, mu=mu, model=model)
+    dataset = optimizer_result_to_dataset(
+        result, method=method, mu=mu, model=model, var_name_to_model_var=var_name_to_model_var
+    )
     idata.add_groups({"optimizer_result": dataset})

     return idata
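The hunks above (apparently pymc_extras/inference/laplace_approx/idata.py, judging by the files-changed list and the functions touched) thread a new var_name_to_model_var mapping through the result-conversion helpers, so that optimizer-side variable names can be translated back to model value variables. A minimal sketch of how the mapping could be used; the toy model and the "opt_mu" rename are hypothetical, and the exact label format of the returned names may differ:

import pymc as pm

# assumed import location, per the files-changed list
from pymc_extras.inference.laplace_approx.idata import make_unpacked_variable_names

with pm.Model(coords={"city": ["a", "b", "c"]}) as model:
    mu = pm.Normal("mu", dims="city")
    sigma = pm.HalfNormal("sigma")

# Names as they appear in the raveled optimization result; "sigma_log__" is the
# transformed value variable PyMC creates for the HalfNormal.
names = [v.name for v in model.value_vars]
print(make_unpacked_variable_names(names, model))
# something like ["mu[a]", "mu[b]", "mu[c]", "sigma_log__"]

# If the optimizer worked on renamed auxiliary variables (as DADVI does), the new
# argument maps the optimizer-side name back onto a model value variable:
print(make_unpacked_variable_names(["opt_mu"], model, {"opt_mu": "mu"}))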
@@ -16,9 +16,7 @@
 import logging

 from collections.abc import Callable
-from functools import partial
 from typing import Literal
-from typing import cast as type_cast

 import arviz as az
 import numpy as np
@@ -27,16 +25,18 @@ import pytensor
 import pytensor.tensor as pt
 import xarray as xr

+from arviz import dict_to_dataset
 from better_optimize.constants import minimize_method
 from numpy.typing import ArrayLike
+from pymc import Model
+from pymc.backends.arviz import coords_and_dims_for_inferencedata
 from pymc.blocking import DictToArrayBijection
 from pymc.model.transform.optimization import freeze_dims_and_data
-from pymc.pytensorf import join_nonshared_inputs
-from pymc.util import get_default_varnames
+from pymc.util import get_untransformed_name, is_transformed_name
 from pytensor.graph import vectorize_graph
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.optimize import minimize
-from pytensor.tensor.type import Variable
+from xarray import Dataset

 from pymc_extras.inference.laplace_approx.find_map import (
     _compute_inverse_hessian,
@@ -147,130 +147,175 @@ def get_conditional_gaussian_approximation(
     return pytensor.function(args, [x0, conditional_gaussian_approx])


-def _unconstrained_vector_to_constrained_rvs(model):
-    outputs = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
-    constrained_names = [
-        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
-    ]
-    names = [x.name for x in outputs]
+def unpack_last_axis(packed_input, packed_shapes):
+    if len(packed_shapes) == 1:
+        # Single case currently fails in unpack
+        return [pt.split_dims(packed_input, packed_shapes[0], axis=-1)]

-    unconstrained_names = [name for name in names if name not in constrained_names]
+    keep_axes = tuple(range(packed_input.ndim))[:-1]
+    return pt.unpack(packed_input, axes=keep_axes, packed_shapes=packed_shapes)

-    new_outputs, unconstrained_vector = join_nonshared_inputs(
-        model.initial_point(),
-        inputs=model.value_vars,
-        outputs=outputs,
-    )
-
-    constrained_rvs = [x for x, name in zip(new_outputs, names) if name in constrained_names]
-    value_rvs = [x for x in new_outputs if x not in constrained_rvs]
-
-    unconstrained_vector.name = "unconstrained_vector"

-    # Redo the names list to ensure it is sorted to match the return order
-    names = [*constrained_names, *unconstrained_names]
+def draws_from_laplace_approx(
+    *,
+    mean,
+    covariance=None,
+    standard_deviation=None,
+    draws: int,
+    model: Model,
+    vectorize_draws: bool = True,
+    return_unconstrained: bool = True,
+    random_seed=None,
+    compile_kwargs: dict | None = None,
+) -> tuple[Dataset, Dataset | None]:
+    """
+    Generate draws from the Laplace approximation of the posterior.

-    return names, constrained_rvs, value_rvs, unconstrained_vector
+    Parameters
+    ----------
+    mean : np.ndarray
+        The mean of the Laplace approximation (MAP estimate).
+    covariance : np.ndarray, optional
+        The covariance matrix of the Laplace approximation.
+        Mutually exclusive with `standard_deviation`.
+    standard_deviation : np.ndarray, optional
+        The standard deviation of the Laplace approximation (diagonal approximation).
+        Mutually exclusive with `covariance`.
+    draws : int
+        The number of draws.
+    model : pm.Model
+        The PyMC model.
+    vectorize_draws : bool, default True
+        Whether to vectorize the draws.
+    return_unconstrained : bool, default True
+        Whether to return the unconstrained draws in addition to the constrained ones.
+    random_seed : int, optional
+        Random seed for reproducibility.
+    compile_kwargs: dict, optional
+        Optional compile kwargs

+    Returns
+    -------
+    tuple[Dataset, Dataset | None]
+        A tuple containing the constrained draws (trace) and optionally the unconstrained draws.
+
+    Raises
+    ------
+    ValueError
+        If neither `covariance` nor `standard_deviation` is provided,
+        or if both are provided.
+    """
+    # This function assumes that mean/covariance/standard_deviation are aligned with model.initial_point()
+    if covariance is None and standard_deviation is None:
+        raise ValueError("Must specify either covariance or standard_deviation")
+    if covariance is not None and standard_deviation is not None:
+        raise ValueError("Cannot specify both covariance and standard_deviation")
+    if compile_kwargs is None:
+        compile_kwargs = {}

-def model_to_laplace_approx(
-    model: pm.Model, unpacked_variable_names: list[str], chains: int = 1, draws: int = 500
-):
     initial_point = model.initial_point()
-    raveled_vars = DictToArrayBijection.map(initial_point)
-    raveled_shape = raveled_vars.data.shape[0]
-
-    # temp_chain and temp_draw are a hack to allow sampling from the Laplace approximation. We only have one mu and cov,
-    # so we add batch dims (which correspond to chains and draws). But the names "chain" and "draw" are reserved.
-    names, constrained_rvs, value_rvs, unconstrained_vector = (
-        _unconstrained_vector_to_constrained_rvs(model)
+    n = int(np.sum([np.prod(v.shape) for v in initial_point.values()]))
+    assert mean.shape == (n,)
+    if covariance is not None:
+        assert covariance.shape == (n, n)
+    elif standard_deviation is not None:
+        assert standard_deviation.shape == (n,)
+
+    vars_to_sample = [v for v in model.free_RVs + model.deterministics]
+    var_names = [v.name for v in vars_to_sample]
+
+    orig_constrained_vars = model.value_vars
+    orig_outputs = model.replace_rvs_by_values(vars_to_sample)
+    if return_unconstrained:
+        orig_outputs.extend(model.value_vars)
+
+    mu_pt = pt.vector("mu", shape=(n,), dtype=mean.dtype)
+    size = (draws,) if vectorize_draws else ()
+    if covariance is not None:
+        sigma_pt = pt.matrix("cov", shape=(n, n), dtype=covariance.dtype)
+        laplace_approximation = pm.MvNormal.dist(mu=mu_pt, cov=sigma_pt, size=size, method="svd")
+    else:
+        sigma_pt = pt.vector("sigma", shape=(n,), dtype=standard_deviation.dtype)
+        laplace_approximation = pm.Normal.dist(mu=mu_pt, sigma=sigma_pt, size=(*size, n))
+
+    constrained_vars = unpack_last_axis(
+        laplace_approximation,
+        [initial_point[v.name].shape for v in orig_constrained_vars],
+    )
+    outputs = vectorize_graph(
+        orig_outputs, replace=dict(zip(orig_constrained_vars, constrained_vars))
     )

-    coords = model.coords | {
-        "temp_chain": np.arange(chains),
-        "temp_draw": np.arange(draws),
-        "unpacked_variable_names": unpacked_variable_names,
-    }
-
-    with pm.Model(coords=coords, model=None) as laplace_model:
-        mu = pm.Flat("mean_vector", shape=(raveled_shape,))
-        cov = pm.Flat("covariance_matrix", shape=(raveled_shape, raveled_shape))
-        laplace_approximation = pm.MvNormal(
-            "laplace_approximation",
-            mu=mu,
-            cov=cov,
-            dims=["temp_chain", "temp_draw", "unpacked_variable_names"],
-            method="svd",
-        )
-
-        cast_to_var = partial(type_cast, Variable)
-        batched_rvs = vectorize_graph(
-            type_cast(list[Variable], constrained_rvs),
-            replace={cast_to_var(unconstrained_vector): cast_to_var(laplace_approximation)},
-        )
-
-        for name, batched_rv in zip(names, batched_rvs):
-            batch_dims = ("temp_chain", "temp_draw")
-            if batched_rv.ndim == 2:
-                dims = batch_dims
-            elif name in model.named_vars_to_dims:
-                dims = (*batch_dims, *model.named_vars_to_dims[name])
-            else:
-                dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
-            initval = initial_point.get(name, None)
-            dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
-            laplace_model.add_coords(
-                {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
-            )
-
-            pm.Deterministic(name, batched_rv, dims=dims)
-
-    return laplace_model
-
-
-def unstack_laplace_draws(laplace_data, model, chains=2, draws=500):
-    """
-    The `model_to_laplace_approx` function returns a model with a single MvNormal distribution, draws from which are
-    in the unconstrained variable space. These might be interesting to the user, but since they come back stacked in a
-    single vector, it's not easy to work with.
-
-    This function unpacks each component of the vector into its own DataArray, with the appropriate dimensions and
-    coordinates, where possible.
-    """
-    initial_point = DictToArrayBijection.map(model.initial_point())
-
-    cursor = 0
-    unstacked_laplace_draws = {}
-    coords = model.coords | {"chain": range(chains), "draw": range(draws)}
-
-    # There are corner cases where the value_vars will not have the same dimensions as the random variable (e.g.
-    # simplex transform of a Dirichlet). In these cases, we don't try to guess what the labels should be, and just
-    # add an arviz-style default dim and label.
-    for rv, (name, shape, size, dtype) in zip(model.free_RVs, initial_point.point_map_info):
-        rv_dims = []
-        for i, dim in enumerate(
-            model.named_vars_to_dims.get(rv.name, [f"{name}_dim_{i}" for i in range(len(shape))])
-        ):
-            if coords.get(dim) and shape[i] == len(coords[dim]):
-                rv_dims.append(dim)
-            else:
-                rv_dims.append(f"{name}_dim_{i}")
-                coords[f"{name}_dim_{i}"] = np.arange(shape[i])
-
-        dims = ("chain", "draw", *rv_dims)
-
-        values = (
-            laplace_data[..., cursor : cursor + size].reshape((chains, draws, *shape)).astype(dtype)
+    fn = pm.pytensorf.compile(
+        [mu_pt, sigma_pt],
+        outputs,
+        random_seed=random_seed,
+        trust_input=True,
+        **compile_kwargs,
+    )
+    sigma = covariance if covariance is not None else standard_deviation
+    if vectorize_draws:
+        output_buffers = fn(mean, sigma)
+    else:
+        # Take one draw to find the shape of the outputs
+        output_buffers = []
+        for out_draw in fn(mean, sigma):
+            output_buffer = np.empty((draws, *out_draw.shape), dtype=out_draw.dtype)
+            output_buffer[0] = out_draw
+            output_buffers.append(output_buffer)
+        # Fill one draws at a time
+        for i in range(1, draws):
+            for out_buffer, out_draw in zip(output_buffers, fn(mean, sigma)):
+                out_buffer[i] = out_draw
+
+    model_coords, model_dims = coords_and_dims_for_inferencedata(model)
+    posterior = {
+        var_name: out_buffer[None]
+        for var_name, out_buffer in (
+            zip(var_names, output_buffers, strict=not return_unconstrained)
         )
-        unstacked_laplace_draws[name] = xr.DataArray(
-            values, dims=dims, coords={dim: list(coords[dim]) for dim in dims}
+    }
+    posterior_dataset = dict_to_dataset(posterior, coords=model_coords, dims=model_dims, library=pm)
+    unconstrained_posterior_dataset = None
+
+    if return_unconstrained:
+        unconstrained_posterior = {
+            var.name: out_buffer[None]
+            for var, out_buffer in zip(
+                model.value_vars, output_buffers[len(posterior) :], strict=True
+            )
+        }
+        # Attempt to map constrained dims to unconstrained dims
+        for var_name, var_draws in unconstrained_posterior.items():
+            if not is_transformed_name(var_name):
+                # constrained == unconstrained, dims already shared
+                continue
+            constrained_dims = model_dims.get(get_untransformed_name(var_name))
+            if constrained_dims is None or (len(constrained_dims) != (var_draws.ndim - 2)):
+                continue
+            # Reuse dims from constrained variable if they match in length with unconstrained draws
+            inferred_dims = []
+            for i, (constrained_dim, unconstrained_dim_length) in enumerate(
+                zip(constrained_dims, var_draws.shape[2:], strict=True)
+            ):
+                if model_coords.get(constrained_dim) is not None and (
+                    len(model_coords[constrained_dim]) == unconstrained_dim_length
+                ):
+                    # Assume coordinates map. This could be fooled, by e.g., having a transform that reverses values
+                    inferred_dims.append(constrained_dim)
+                else:
+                    # Size mismatch (e.g., Simplex), make no assumption about mapping
+                    inferred_dims.append(f"{var_name}_dim_{i}")
+            model_dims[var_name] = inferred_dims
+
+        unconstrained_posterior_dataset = dict_to_dataset(
+            unconstrained_posterior,
+            coords=model_coords,
+            dims=model_dims,
+            library=pm,
         )

-        cursor += size
-
-    unstacked_laplace_draws = xr.Dataset(unstacked_laplace_draws)
-
-    return unstacked_laplace_draws
+    return posterior_dataset, unconstrained_posterior_dataset


 def fit_laplace(
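The new draws_from_laplace_approx helper (these hunks appear to come from pymc_extras/inference/laplace_approx/laplace.py) replaces model_to_laplace_approx and unstack_laplace_draws: it draws from a single MvNormal (or diagonal Normal) over the packed unconstrained vector, vectorizes the model graph over those draws, and returns constrained and, optionally, unconstrained draws as xarray Datasets. A minimal sketch of calling it directly; the import path, the toy model, and the made-up mean/covariance are assumptions:

import numpy as np
import pymc as pm

# assumed location of the new helper
from pymc_extras.inference.laplace_approx.laplace import draws_from_laplace_approx

with pm.Model() as model:
    a = pm.Normal("a")
    b = pm.HalfNormal("b")

# mean/covariance are packed in model.initial_point() order ("a", "b_log__");
# in fit_laplace they come from the MAP estimate and the inverse Hessian.
mean = np.zeros(2)
cov = np.diag([0.1, 0.2])

posterior, unconstrained = draws_from_laplace_approx(
    mean=mean,
    covariance=cov,
    draws=1000,
    model=model,
    return_unconstrained=True,
    random_seed=1,
)
print(posterior["a"].shape)  # expected (1, 1000): a single chain dimension is prepended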
@@ -285,9 +330,11 @@ def fit_laplace(
     jitter_rvs: list[pt.TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
+    freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
-    chains: int = 2,
+    chains: None | int = None,
     draws: int = 500,
+    vectorize_draws: bool = True,
     optimizer_kwargs: dict | None = None,
     compile_kwargs: dict | None = None,
 ) -> az.InferenceData:
@@ -328,18 +375,20 @@ def fit_laplace(
     include_transformed: bool, default True
         Whether to include transformed variables in the output. If True, transformed variables will be included in the
         output InferenceData object. If False, only the original variables will be included.
+    freeze_model: bool, optional
+        If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
+        sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
+        True.
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
-    chains: int, default: 2
-        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
-        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
-        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior. Totals samples will be chains * draws.
+        The number of samples to draw from the approximated posterior.
     optimizer_kwargs
         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
         ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
+    vectorize_draws: bool, default True
+        Whether to natively vectorize the random function or take one at a time in a python loop.
     compile_kwargs: dict, optional
         Additional keyword arguments to pass to pytensor.function.

@@ -354,7 +403,7 @@ def fit_laplace(
     >>> import numpy as np
     >>> import pymc as pm
     >>> import arviz as az
-    >>> y = np.array([2642, 3503, 4358]*10)
+    >>> y = np.array([2642, 3503, 4358] * 10)
     >>> with pm.Model() as m:
     >>>     logsigma = pm.Uniform("logsigma", 1, 100)
     >>>     mu = pm.Uniform("mu", -10000, 10000)
@@ -372,10 +421,19 @@ def fit_laplace(
     will forward the call to 'fit_laplace'.

     """
+    if chains is not None:
+        raise ValueError(
+            "chains argument has been deprecated. "
+            "The behavior can be recreated by unstacking draws into multiple chains after fitting"
+        )
+
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
     model = pm.modelcontext(model) if model is None else model

+    if freeze_model:
+        model = freeze_dims_and_data(model)
+
     idata = find_MAP(
         method=optimize_method,
         model=model,
@@ -387,17 +445,17 @@ def fit_laplace(
         jitter_rvs=jitter_rvs,
         progressbar=progressbar,
         include_transformed=include_transformed,
+        freeze_model=False,
         gradient_backend=gradient_backend,
         compile_kwargs=compile_kwargs,
         compute_hessian=True,
         **optimizer_kwargs,
     )

-    unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
-
     if "covariance_matrix" not in idata.fit:
         # The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so
         # we have to go back and compute the Hessian at the MAP point now.
+        unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
         frozen_model = freeze_dims_and_data(model)
         initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)

@@ -426,29 +484,16 @@ def fit_laplace(
         coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
     )

-    with model_to_laplace_approx(model, unpacked_variable_names, chains, draws) as laplace_model:
-        new_posterior = (
-            pm.sample_posterior_predictive(
-                idata.fit.expand_dims(chain=[0], draw=[0]),
-                extend_inferencedata=False,
-                random_seed=random_seed,
-                var_names=[
-                    "laplace_approximation",
-                    *[x.name for x in laplace_model.deterministics],
-                ],
-            )
-            .posterior_predictive.squeeze(["chain", "draw"])
-            .drop_vars(["chain", "draw"])
-            .rename({"temp_chain": "chain", "temp_draw": "draw"})
-        )
-
-        if include_transformed:
-            idata.unconstrained_posterior = unstack_laplace_draws(
-                new_posterior.laplace_approximation.values, model, chains=chains, draws=draws
-            )
-
-        idata.posterior = new_posterior.drop_vars(
-            ["laplace_approximation", "unpacked_variable_names"]
-        )
-
+    # We override the posterior/unconstrained_posterior from find_MAP
+    idata.posterior, unconstrained_posterior = draws_from_laplace_approx(
+        mean=idata.fit["mean_vector"].values,
+        covariance=idata.fit["covariance_matrix"].values,
+        draws=draws,
+        return_unconstrained=include_transformed,
+        model=model,
+        vectorize_draws=vectorize_draws,
+        random_seed=random_seed,
+    )
+    if include_transformed:
+        idata.unconstrained_posterior = unconstrained_posterior
     return idata
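With the chains argument removed, fit_laplace now returns a single chain of draws and exposes the new freeze_model and vectorize_draws flags. A sketch of calling it, reusing the model from the function's own docstring example; the observed-data line and the default optimizer are assumptions:

import numpy as np
import pymc as pm

# fit_laplace is defined in the module shown above
from pymc_extras.inference.laplace_approx.laplace import fit_laplace

y = np.array([2642, 3503, 4358] * 10)
with pm.Model() as m:
    logsigma = pm.Uniform("logsigma", 1, 100)
    mu = pm.Uniform("mu", -10000, 10000)
    pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)  # assumed likelihood

    idata = fit_laplace(
        draws=1000,            # total number of posterior draws; passing chains= now raises
        vectorize_draws=True,  # sample all draws in one vectorized call instead of a Python loop
        freeze_model=True,     # freeze dims/data before compiling the loss (the new default)
    )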
@@ -1,3 +1,5 @@
+import logging
+
 from collections.abc import Callable
 from importlib.util import find_spec
 from typing import Literal, get_args
@@ -6,6 +8,7 @@ import numpy as np
 import pymc as pm
 import pytensor

+from better_optimize.constants import MINIMIZE_MODE_KWARGS
 from pymc import join_nonshared_inputs
 from pytensor import tensor as pt
 from pytensor.compile import Function
@@ -14,6 +17,39 @@ from pytensor.tensor import TensorVariable
 GradientBackend = Literal["pytensor", "jax"]
 VALID_BACKENDS = get_args(GradientBackend)

+_log = logging.getLogger(__name__)
+
+
+def set_optimizer_function_defaults(
+    method: str, use_grad: bool | None, use_hess: bool | None, use_hessp: bool | None
+):
+    method_info = MINIMIZE_MODE_KWARGS[method].copy()
+
+    if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
+        use_hess = False
+
+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
+    return use_grad, use_hess, use_hessp
+

 def _compile_grad_and_hess_to_jax(
     f_fused: Function, use_hess: bool, use_hessp: bool
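The new set_optimizer_function_defaults resolves the use_grad/use_hess/use_hessp trio against better_optimize's MINIMIZE_MODE_KWARGS table. A small sketch of the resolution rules it implements; the module path and the capability-table entries for specific methods are assumptions, while the first two commented results follow directly from the code above:

# assumed location; the hunks above appear to belong to pymc_extras/inference/laplace_approx/scipy_interface.py
from pymc_extras.inference.laplace_approx.scipy_interface import set_optimizer_function_defaults

# Both Hessian modes requested: a warning is logged and use_hess is dropped.
print(set_optimizer_function_defaults("trust-ncg", True, True, True))   # (True, False, True)

# Only use_hess given: use_hessp becomes its complement.
print(set_optimizer_function_defaults("trust-ncg", True, True, None))   # (True, True, False)

# Nothing given: fall back to the method's entry in MINIMIZE_MODE_KWARGS,
# preferring hessp whenever the method could use either.
print(set_optimizer_function_defaults("L-BFGS-B", None, None, None))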
@@ -144,12 +180,13 @@ def _compile_functions_for_scipy_optimize(
 def scipy_optimize_funcs_from_loss(
     loss: TensorVariable,
     inputs: list[TensorVariable],
-    initial_point_dict: dict[str, np.ndarray | float | int],
-    use_grad: bool,
-    use_hess: bool,
-    use_hessp: bool,
+    initial_point_dict: dict[str, np.ndarray | float | int] | None = None,
+    use_grad: bool | None = None,
+    use_hess: bool | None = None,
+    use_hessp: bool | None = None,
     gradient_backend: GradientBackend = "pytensor",
     compile_kwargs: dict | None = None,
+    inputs_are_flat: bool = False,
 ) -> tuple[Callable, ...]:
     """
     Compile loss functions for use with scipy.optimize.minimize.
@@ -206,9 +243,12 @@
     if not isinstance(inputs, list):
         inputs = [inputs]

-    [loss], flat_input = join_nonshared_inputs(
-        point=initial_point_dict, outputs=[loss], inputs=inputs
-    )
+    if inputs_are_flat:
+        [flat_input] = inputs
+    else:
+        [loss], flat_input = join_nonshared_inputs(
+            point=initial_point_dict, outputs=[loss], inputs=inputs
+        )

     # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
     # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
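The new inputs_are_flat flag lets a caller skip the join_nonshared_inputs step when the loss is already expressed in terms of a single raveled input vector (as DADVI does with its packed parameters). A sketch of what the two paths correspond to; the toy model is hypothetical and the commented call at the end only indicates the intended usage:

import pymc as pm
from pymc import join_nonshared_inputs

with pm.Model() as m:
    a = pm.Normal("a", shape=(2,))
    b = pm.HalfNormal("b")

# Default path (inputs_are_flat=False): the helper itself ravels all value variables
# into one flat vector via join_nonshared_inputs, using initial_point_dict for shapes.
point = m.initial_point()
[neg_logp], flat_input = join_nonshared_inputs(point=point, outputs=[-m.logp()], inputs=m.value_vars)

# New path (inputs_are_flat=True): the caller has already done the join, so the single
# flat input is used as-is, e.g. (assumed call):
# scipy_optimize_funcs_from_loss(loss=neg_logp, inputs=[flat_input], inputs_are_flat=True,
#                                use_grad=True, use_hess=False, use_hessp=True)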