pymc-extras 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a public registry. It is provided for informational purposes only.
Files changed (37)
  1. pymc_extras/deserialize.py +10 -4
  2. pymc_extras/distributions/continuous.py +1 -1
  3. pymc_extras/distributions/histogram_utils.py +6 -4
  4. pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
  5. pymc_extras/distributions/timeseries.py +4 -2
  6. pymc_extras/inference/__init__.py +8 -1
  7. pymc_extras/inference/dadvi/__init__.py +0 -0
  8. pymc_extras/inference/dadvi/dadvi.py +351 -0
  9. pymc_extras/inference/fit.py +5 -0
  10. pymc_extras/inference/laplace_approx/find_map.py +32 -47
  11. pymc_extras/inference/laplace_approx/idata.py +27 -6
  12. pymc_extras/inference/laplace_approx/laplace.py +24 -6
  13. pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
  14. pymc_extras/inference/pathfinder/idata.py +517 -0
  15. pymc_extras/inference/pathfinder/pathfinder.py +61 -7
  16. pymc_extras/model/marginal/graph_analysis.py +2 -2
  17. pymc_extras/model_builder.py +9 -4
  18. pymc_extras/prior.py +203 -8
  19. pymc_extras/statespace/core/compile.py +1 -1
  20. pymc_extras/statespace/filters/kalman_filter.py +12 -11
  21. pymc_extras/statespace/filters/kalman_smoother.py +1 -3
  22. pymc_extras/statespace/filters/utilities.py +2 -5
  23. pymc_extras/statespace/models/DFM.py +834 -0
  24. pymc_extras/statespace/models/ETS.py +190 -198
  25. pymc_extras/statespace/models/SARIMAX.py +9 -21
  26. pymc_extras/statespace/models/VARMAX.py +22 -74
  27. pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
  28. pymc_extras/statespace/models/structural/components/regression.py +4 -26
  29. pymc_extras/statespace/models/utilities.py +7 -0
  30. pymc_extras/statespace/utils/constants.py +3 -1
  31. pymc_extras/utils/model_equivalence.py +2 -2
  32. pymc_extras/utils/prior.py +10 -14
  33. pymc_extras/utils/spline.py +4 -10
  34. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/METADATA +3 -3
  35. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/RECORD +37 -33
  36. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/WHEEL +1 -1
  37. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/licenses/LICENSE +0 -0
pymc_extras/inference/laplace_approx/idata.py
@@ -22,10 +22,15 @@ def make_default_labels(name: str, shape: tuple[int, ...]) -> list:
     return [list(range(dim)) for dim in shape]
 
 
-def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]:
+def make_unpacked_variable_names(
+    names: list[str], model: pm.Model, var_name_to_model_var: dict[str, str] | None = None
+) -> list[str]:
     coords = model.coords
     initial_point = model.initial_point()
 
+    if var_name_to_model_var is None:
+        var_name_to_model_var = {}
+
     value_to_dim = {
         value.name: model.named_vars_to_dims.get(model.values_to_rvs[value].name, None)
         for value in model.value_vars
@@ -37,6 +42,7 @@ def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]
 
     unpacked_variable_names = []
     for name in names:
+        name = var_name_to_model_var.get(name, name)
         shape = initial_point[name].shape
        if shape:
            dims = dims_dict.get(name)
@@ -109,7 +115,7 @@ def map_results_to_inference_data(
         x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
     ]
 
-    unconstrained_names = set(all_varnames) - set(constrained_names)
+    unconstrained_names = sorted(set(all_varnames) - set(constrained_names))
 
     idata = az.from_dict(
         posterior={
@@ -136,7 +142,10 @@ def map_results_to_inference_data(
 
 
 def add_fit_to_inference_data(
-    idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
+    idata: az.InferenceData,
+    mu: RaveledVars,
+    H_inv: np.ndarray | None,
+    model: pm.Model | None = None,
 ) -> az.InferenceData:
     """
     Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
@@ -147,7 +156,7 @@ def add_fit_to_inference_data(
         An InferenceData object containing the approximated posterior samples.
     mu: RaveledVars
         The MAP estimate of the model parameters.
-    H_inv: np.ndarray
+    H_inv: np.ndarray, optional
         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     model: Model, optional
         A PyMC model. If None, the model is taken from the current model context.
@@ -255,6 +264,7 @@ def optimizer_result_to_dataset(
     method: minimize_method | Literal["basinhopping"],
     mu: RaveledVars | None = None,
     model: pm.Model | None = None,
+    var_name_to_model_var: dict[str, str] | None = None,
 ) -> xr.Dataset:
     """
     Convert an OptimizeResult object to an xarray Dataset object.
@@ -265,6 +275,9 @@ def optimizer_result_to_dataset(
         The result of the optimization process.
     method: minimize_method or "basinhopping"
         The optimization method used.
+    var_name_to_model_var: dict, optional
+        Mapping between variables in the optimization result and the model variable names. Used when auxiliary
+        variables were introduced, e.g. in DADVI.
 
     Returns
     -------
@@ -276,7 +289,9 @@ def optimizer_result_to_dataset(
 
     model = pm.modelcontext(model) if model is None else model
     variable_names, *_ = zip(*mu.point_map_info)
-    unpacked_variable_names = make_unpacked_variable_names(variable_names, model)
+    unpacked_variable_names = make_unpacked_variable_names(
+        variable_names, model, var_name_to_model_var
+    )
 
     data_vars = {}
 
@@ -365,6 +380,7 @@ def add_optimizer_result_to_inference_data(
     method: minimize_method | Literal["basinhopping"],
     mu: RaveledVars | None = None,
     model: pm.Model | None = None,
+    var_name_to_model_var: dict[str, str] | None = None,
 ) -> az.InferenceData:
     """
     Add the optimization result to an InferenceData object.
@@ -381,13 +397,18 @@ def add_optimizer_result_to_inference_data(
         The MAP estimate of the model parameters.
     model: Model, optional
         A PyMC model. If None, the model is taken from the current model context.
+    var_name_to_model_var: dict, optional
+        Mapping between variables in the optimization result and the model variable names. Used when auxiliary
+        variables were introduced, e.g. in DADVI.
 
     Returns
    -------
     idata: az.InferenceData
         The provided InferenceData, with the optimization results added to the "optimizer" group.
     """
-    dataset = optimizer_result_to_dataset(result, method=method, mu=mu, model=model)
+    dataset = optimizer_result_to_dataset(
+        result, method=method, mu=mu, model=model, var_name_to_model_var=var_name_to_model_var
+    )
    idata.add_groups({"optimizer_result": dataset})
 
     return idata
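
The hunks above thread a new var_name_to_model_var mapping from add_optimizer_result_to_inference_data down to make_unpacked_variable_names, so names invented by an optimizer (e.g. DADVI's auxiliary variables) can be translated back to model variable names before shapes and dims are looked up. A minimal sketch of the idea; the auxiliary name "mu_mean" and the coords below are made up for illustration and are not part of this diff:

```python
import pymc as pm

from pymc_extras.inference.laplace_approx.idata import make_unpacked_variable_names

with pm.Model(coords={"city": ["nyc", "sf", "chi"]}) as model:
    mu = pm.Normal("mu", dims="city")

# Suppose the optimizer reported its parameters under an auxiliary name ("mu_mean").
# The mapping lets the unpacking code resolve shapes and dims via the model's "mu".
labels = make_unpacked_variable_names(
    names=["mu_mean"],
    model=model,
    var_name_to_model_var={"mu_mean": "mu"},
)
```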
pymc_extras/inference/laplace_approx/laplace.py
@@ -168,9 +168,13 @@ def _unconstrained_vector_to_constrained_rvs(model):
     unconstrained_vector.name = "unconstrained_vector"
 
     # Redo the names list to ensure it is sorted to match the return order
-    names = [*constrained_names, *unconstrained_names]
+    constrained_rvs_and_names = [(rv, name) for rv, name in zip(constrained_rvs, constrained_names)]
+    value_rvs_and_names = [
+        (rv, name) for rv, name in zip(value_rvs, names) for name in unconstrained_names
+    ]
+    # names = [*constrained_names, *unconstrained_names]
 
-    return names, constrained_rvs, value_rvs, unconstrained_vector
+    return constrained_rvs_and_names, value_rvs_and_names, unconstrained_vector
 
 
 def model_to_laplace_approx(
@@ -182,8 +186,11 @@
 
     # temp_chain and temp_draw are a hack to allow sampling from the Laplace approximation. We only have one mu and cov,
     # so we add batch dims (which correspond to chains and draws). But the names "chain" and "draw" are reserved.
-    names, constrained_rvs, value_rvs, unconstrained_vector = (
-        _unconstrained_vector_to_constrained_rvs(model)
+
+    # The model was frozen during the find_MAP procedure. To ensure we're operating on the same model, freeze it again.
+    frozen_model = freeze_dims_and_data(model)
+    constrained_rvs_and_names, _, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(
+        frozen_model
     )
 
     coords = model.coords | {
@@ -204,12 +211,13 @@
         )
 
         cast_to_var = partial(type_cast, Variable)
+        constrained_rvs, constrained_names = zip(*constrained_rvs_and_names)
         batched_rvs = vectorize_graph(
             type_cast(list[Variable], constrained_rvs),
             replace={cast_to_var(unconstrained_vector): cast_to_var(laplace_approximation)},
         )
 
-        for name, batched_rv in zip(names, batched_rvs):
+        for name, batched_rv in zip(constrained_names, batched_rvs):
             batch_dims = ("temp_chain", "temp_draw")
             if batched_rv.ndim == 2:
                 dims = batch_dims
@@ -285,6 +293,7 @@ def fit_laplace(
     jitter_rvs: list[pt.TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
+    freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
     chains: int = 2,
     draws: int = 500,
@@ -328,6 +337,10 @@
     include_transformed: bool, default True
         Whether to include transformed variables in the output. If True, transformed variables will be included in the
         output InferenceData object. If False, only the original variables will be included.
+    freeze_model: bool, optional
+        If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
+        sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
+        True.
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
@@ -354,7 +367,7 @@
     >>> import numpy as np
     >>> import pymc as pm
     >>> import arviz as az
-    >>> y = np.array([2642, 3503, 4358]*10)
+    >>> y = np.array([2642, 3503, 4358] * 10)
     >>> with pm.Model() as m:
     >>>     logsigma = pm.Uniform("logsigma", 1, 100)
     >>>     mu = pm.Uniform("mu", -10000, 10000)
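
For reference, the new freeze_model flag can be exercised from fit_laplace directly. Below is a minimal sketch adapted from the docstring example in the hunk above; the likelihood line is not visible in this hunk and is filled in here as an assumption, and fit_laplace is assumed to be importable from the package top level as in earlier releases:

```python
import numpy as np
import pymc as pm

from pymc_extras import fit_laplace

y = np.array([2642, 3503, 4358] * 10)

with pm.Model() as m:
    logsigma = pm.Uniform("logsigma", 1, 100)
    mu = pm.Uniform("mu", -10000, 10000)
    y_obs = pm.Normal("y_obs", mu=mu, sigma=pm.math.exp(logsigma), observed=y)

    # freeze_model=True (the default) applies freeze_dims_and_data before the
    # loss functions are compiled; as the hunks below show, find_MAP is then
    # called with freeze_model=False so the model is not frozen a second time.
    idata = fit_laplace(freeze_model=True)
```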
@@ -376,6 +389,9 @@
     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
     model = pm.modelcontext(model) if model is None else model
 
+    if freeze_model:
+        model = freeze_dims_and_data(model)
+
     idata = find_MAP(
         method=optimize_method,
         model=model,
@@ -387,8 +403,10 @@
         jitter_rvs=jitter_rvs,
         progressbar=progressbar,
         include_transformed=include_transformed,
+        freeze_model=False,
         gradient_backend=gradient_backend,
         compile_kwargs=compile_kwargs,
+        compute_hessian=True,
         **optimizer_kwargs,
     )
 
pymc_extras/inference/laplace_approx/scipy_interface.py
@@ -1,3 +1,5 @@
+import logging
+
 from collections.abc import Callable
 from importlib.util import find_spec
 from typing import Literal, get_args
@@ -6,6 +8,7 @@ import numpy as np
 import pymc as pm
 import pytensor
 
+from better_optimize.constants import MINIMIZE_MODE_KWARGS
 from pymc import join_nonshared_inputs
 from pytensor import tensor as pt
 from pytensor.compile import Function
@@ -14,6 +17,39 @@ from pytensor.tensor import TensorVariable
 GradientBackend = Literal["pytensor", "jax"]
 VALID_BACKENDS = get_args(GradientBackend)
 
+_log = logging.getLogger(__name__)
+
+
+def set_optimizer_function_defaults(
+    method: str, use_grad: bool | None, use_hess: bool | None, use_hessp: bool | None
+):
+    method_info = MINIMIZE_MODE_KWARGS[method].copy()
+
+    if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
+        use_hess = False
+
+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
+    return use_grad, use_hess, use_hessp
+
 
 def _compile_grad_and_hess_to_jax(
     f_fused: Function, use_hess: bool, use_hessp: bool
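
The new set_optimizer_function_defaults helper resolves which derivative functions to compile when the caller leaves use_grad / use_hess / use_hessp unset. A small sketch of the resolution rules shown in the hunk above, assuming the method name used here is a valid key of better_optimize's MINIMIZE_MODE_KWARGS:

```python
from pymc_extras.inference.laplace_approx.scipy_interface import (
    set_optimizer_function_defaults,
)

# Only use_hessp was requested, so use_hess is inferred as its complement.
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
    "trust-ncg", use_grad=True, use_hess=None, use_hessp=True
)
assert (use_grad, use_hess, use_hessp) == (True, False, True)

# Requesting both hess and hessp logs a warning and drops use_hess, since
# scipy.optimize.minimize only ever uses one of the two.
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
    "trust-ncg", use_grad=True, use_hess=True, use_hessp=True
)
assert (use_grad, use_hess, use_hessp) == (True, False, True)
```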
@@ -144,12 +180,13 @@ def _compile_functions_for_scipy_optimize(
 def scipy_optimize_funcs_from_loss(
     loss: TensorVariable,
     inputs: list[TensorVariable],
-    initial_point_dict: dict[str, np.ndarray | float | int],
-    use_grad: bool,
-    use_hess: bool,
-    use_hessp: bool,
+    initial_point_dict: dict[str, np.ndarray | float | int] | None = None,
+    use_grad: bool | None = None,
+    use_hess: bool | None = None,
+    use_hessp: bool | None = None,
     gradient_backend: GradientBackend = "pytensor",
     compile_kwargs: dict | None = None,
+    inputs_are_flat: bool = False,
 ) -> tuple[Callable, ...]:
     """
     Compile loss functions for use with scipy.optimize.minimize.
@@ -206,9 +243,12 @@
     if not isinstance(inputs, list):
         inputs = [inputs]
 
-    [loss], flat_input = join_nonshared_inputs(
-        point=initial_point_dict, outputs=[loss], inputs=inputs
-    )
+    if inputs_are_flat:
+        [flat_input] = inputs
+    else:
+        [loss], flat_input = join_nonshared_inputs(
+            point=initial_point_dict, outputs=[loss], inputs=inputs
+        )
 
     # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
     # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
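
The new inputs_are_flat flag lets a caller that already works on a single flattened parameter vector (as the new DADVI module does) skip the join_nonshared_inputs step, in which case initial_point_dict is not needed. A minimal sketch of that code path; the toy loss and variable names are illustrative only, and only the signature and branch shown in these hunks are assumed:

```python
import numpy as np
import pytensor.tensor as pt

from pymc_extras.inference.laplace_approx.scipy_interface import (
    scipy_optimize_funcs_from_loss,
)

# A toy quadratic loss defined directly on one flat parameter vector.
x = pt.vector("x")
target = np.array([1.0, -2.0, 0.5])
loss = pt.sum((x - target) ** 2)

# With inputs_are_flat=True, the single input is taken to be the already
# flattened parameter vector, so no initial_point_dict is required.
funcs = scipy_optimize_funcs_from_loss(
    loss=loss,
    inputs=[x],
    use_grad=True,
    use_hessp=True,
    inputs_are_flat=True,
)
```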