pymc-extras 0.2.7__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares two publicly available versions of the package as released to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
Files changed (33)
  1. pymc_extras/inference/__init__.py +2 -2
  2. pymc_extras/inference/fit.py +1 -1
  3. pymc_extras/inference/laplace_approx/__init__.py +0 -0
  4. pymc_extras/inference/laplace_approx/find_map.py +354 -0
  5. pymc_extras/inference/laplace_approx/idata.py +393 -0
  6. pymc_extras/inference/laplace_approx/laplace.py +453 -0
  7. pymc_extras/inference/laplace_approx/scipy_interface.py +242 -0
  8. pymc_extras/inference/pathfinder/pathfinder.py +3 -4
  9. pymc_extras/linearmodel.py +3 -1
  10. pymc_extras/model/marginal/graph_analysis.py +4 -0
  11. pymc_extras/prior.py +38 -6
  12. pymc_extras/statespace/core/statespace.py +78 -52
  13. pymc_extras/statespace/filters/kalman_smoother.py +1 -1
  14. pymc_extras/statespace/models/structural/__init__.py +21 -0
  15. pymc_extras/statespace/models/structural/components/__init__.py +0 -0
  16. pymc_extras/statespace/models/structural/components/autoregressive.py +188 -0
  17. pymc_extras/statespace/models/structural/components/cycle.py +305 -0
  18. pymc_extras/statespace/models/structural/components/level_trend.py +257 -0
  19. pymc_extras/statespace/models/structural/components/measurement_error.py +137 -0
  20. pymc_extras/statespace/models/structural/components/regression.py +228 -0
  21. pymc_extras/statespace/models/structural/components/seasonality.py +445 -0
  22. pymc_extras/statespace/models/structural/core.py +900 -0
  23. pymc_extras/statespace/models/structural/utils.py +16 -0
  24. pymc_extras/statespace/models/utilities.py +285 -0
  25. pymc_extras/statespace/utils/constants.py +4 -4
  26. pymc_extras/statespace/utils/data_tools.py +3 -2
  27. {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/METADATA +6 -6
  28. {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/RECORD +30 -18
  29. pymc_extras/inference/find_map.py +0 -496
  30. pymc_extras/inference/laplace.py +0 -583
  31. pymc_extras/statespace/models/structural.py +0 -1679
  32. {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/WHEEL +0 -0
  33. {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/licenses/LICENSE +0 -0
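
The largest structural changes are the move of the MAP and Laplace code from pymc_extras/inference/ into the new pymc_extras/inference/laplace_approx/ subpackage, and the split of the monolithic pymc_extras/statespace/models/structural.py into a structural/ package with one module per component. As a rough sketch of what this means for deep imports (the updated __init__.py contents are not shown in this diff, so the function names and re-export behavior below are assumptions, not confirmed by the diff):

    # pymc-extras 0.2.7 (old deep-import paths)
    # from pymc_extras.inference.find_map import find_MAP
    # from pymc_extras.inference.laplace import fit_laplace

    # pymc-extras 0.4.0 (new module locations, per the file list above)
    from pymc_extras.inference.laplace_approx.find_map import find_MAP
    from pymc_extras.inference.laplace_approx.laplace import fit_laplace

    # The statespace structural module is now a package; the public import path is
    # assumed unchanged, with structural/__init__.py re-exporting the components.
    from pymc_extras.statespace.models import structural as st
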
pymc_extras/inference/find_map.py (removed in 0.4.0)
@@ -1,496 +0,0 @@
- import logging
-
- from collections.abc import Callable
- from importlib.util import find_spec
- from typing import Literal, cast, get_args
-
- import numpy as np
- import pymc as pm
- import pytensor
- import pytensor.tensor as pt
-
- from better_optimize import basinhopping, minimize
- from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
- from pymc.blocking import DictToArrayBijection, RaveledVars
- from pymc.initial_point import make_initial_point_fn
- from pymc.model.transform.optimization import freeze_dims_and_data
- from pymc.pytensorf import join_nonshared_inputs
- from pymc.util import get_default_varnames
- from pytensor.compile import Function
- from pytensor.compile.mode import Mode
- from pytensor.tensor import TensorVariable
- from scipy.optimize import OptimizeResult
-
- _log = logging.getLogger(__name__)
-
- GradientBackend = Literal["pytensor", "jax"]
- VALID_BACKENDS = get_args(GradientBackend)
-
-
- def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
-     method_info = MINIMIZE_MODE_KWARGS[method].copy()
-
-     if use_hess and use_hessp:
-         _log.warning(
-             'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
-             'same time. When possible "use_hessp" is preferred because it is computationally more efficient. '
-             'Setting "use_hess" to False.'
-         )
-         use_hess = False
-
-     use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-
-     if use_hessp is not None and use_hess is None:
-         use_hess = not use_hessp
-
-     elif use_hess is not None and use_hessp is None:
-         use_hessp = not use_hess
-
-     elif use_hessp is None and use_hess is None:
-         use_hessp = method_info["uses_hessp"]
-         use_hess = method_info["uses_hess"]
-         if use_hessp and use_hess:
-             # If a method could use either hess or hessp, we default to using hessp
-             use_hess = False
-
-     return use_grad, use_hess, use_hessp
-
-
- def get_nearest_psd(A: np.ndarray) -> np.ndarray:
-     """
-     Compute the nearest positive semi-definite matrix to a given matrix.
-
-     This function takes a square matrix and returns the nearest positive semi-definite matrix using
-     eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms
-     of the Frobenius norm.
-
-     Parameters
-     ----------
-     A : np.ndarray
-         Input square matrix.
-
-     Returns
-     -------
-     np.ndarray
-         The nearest positive semi-definite matrix to the input matrix.
-     """
-     C = (A + A.T) / 2
-     eigval, eigvec = np.linalg.eigh(C)
-     eigval[eigval < 0] = 0
-
-     return eigvec @ np.diag(eigval) @ eigvec.T
-
-
- def _unconstrained_vector_to_constrained_rvs(model):
-     constrained_rvs, unconstrained_vector = join_nonshared_inputs(
-         model.initial_point(),
-         inputs=model.value_vars,
-         outputs=get_default_varnames(model.unobserved_value_vars, include_transformed=False),
-     )
-
-     unconstrained_vector.name = "unconstrained_vector"
-     return constrained_rvs, unconstrained_vector
-
-
- def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, chains, draws):
-     X = pt.tensor("transformed_draws", shape=(chains, draws, H_inv.shape[0]))
-     out = []
-     for rv, idx in slices.items():
-         f = model.rvs_to_transforms[rv]
-         untransformed_X = f.backward(X[..., idx]) if f is not None else X[..., idx]
-
-         if rv in out_shapes:
-             new_shape = (chains, draws) + out_shapes[rv]
-             untransformed_X = untransformed_X.reshape(new_shape)
-
-         out.append(untransformed_X)
-
-     f_untransform = pytensor.function(
-         inputs=[pytensor.In(X, borrow=True)],
-         outputs=pytensor.Out(out, borrow=True),
-         mode=Mode(linker="py", optimizer="FAST_COMPILE"),
-     )
-     return f_untransform(posterior_draws)
-
-
- def _compile_grad_and_hess_to_jax(
-     f_loss: Function, use_hess: bool, use_hessp: bool
- ) -> tuple[Callable, Callable | None, Callable | None]:
-     """
-     Compile loss function gradients using JAX.
-
-     Parameters
-     ----------
-     f_loss: Function
-         The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss,
-         compiled with mode="JAX".
-     use_hess: bool
-         Whether to compile a function to compute the hessian of the loss function.
-     use_hessp: bool
-         Whether to compile a function to compute the hessian-vector product of the loss function.
-
-     Returns
-     -------
-     f_loss_and_grad: Callable
-         The compiled loss function and gradient function.
-     f_hess: Callable | None
-         The compiled hessian function, or None if use_hess is False.
-     f_hessp: Callable | None
-         The compiled hessian-vector product function, or None if use_hessp is False.
-     """
-     import jax
-
-     f_hess = None
-     f_hessp = None
-
-     orig_loss_fn = f_loss.vm.jit_fn
-
-     @jax.jit
-     def loss_fn_jax_grad(x):
-         return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
-
-     f_loss_and_grad = loss_fn_jax_grad
-
-     if use_hessp:
-
-         def f_hessp_jax(x, p):
-             y, u = jax.jvp(lambda x: f_loss_and_grad(x)[1], (x,), (p,))
-             return jax.numpy.stack(u)
-
-         f_hessp = jax.jit(f_hessp_jax)
-
-     if use_hess:
-         _f_hess_jax = jax.jacfwd(lambda x: f_loss_and_grad(x)[1])
-
-         def f_hess_jax(x):
-             return jax.numpy.stack(_f_hess_jax(x))
-
-         f_hess = jax.jit(f_hess_jax)
-
-     return f_loss_and_grad, f_hess, f_hessp
-
-
- def _compile_functions_for_scipy_optimize(
-     loss: TensorVariable,
-     inputs: list[TensorVariable],
-     compute_grad: bool,
-     compute_hess: bool,
-     compute_hessp: bool,
-     compile_kwargs: dict | None = None,
- ) -> list[Function] | list[Function, Function | None, Function | None]:
-     """
-     Compile loss functions for use with scipy.optimize.minimize.
-
-     Parameters
-     ----------
-     loss: TensorVariable
-         The loss function to compile.
-     inputs: list[TensorVariable]
-         A single flat vector input variable, collecting all inputs to the loss function. Scipy optimize routines
-         expect the function signature to be f(x, *args), where x is a 1D array of parameters.
-     compute_grad: bool
-         Whether to compile a function that computes the gradients of the loss function.
-     compute_hess: bool
-         Whether to compile a function that computes the Hessian of the loss function.
-     compute_hessp: bool
-         Whether to compile a function that computes the Hessian-vector product of the loss function.
-     compile_kwargs: dict, optional
-         Additional keyword arguments to pass to the ``pm.compile`` function.
-
-     Returns
-     -------
-     f_loss: Function
-
-     f_hess: Function | None
-     f_hessp: Function | None
-     """
-     loss = pm.pytensorf.rewrite_pregrad(loss)
-     f_hess = None
-     f_hessp = None
-
-     if compute_grad:
-         grads = pytensor.gradient.grad(loss, inputs)
-         grad = pt.concatenate([grad.ravel() for grad in grads])
-         f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
-     else:
-         f_loss = pm.compile(inputs, loss, **compile_kwargs)
-         return [f_loss]
-
-     if compute_hess:
-         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-         f_hess = pm.compile(inputs, hess, **compile_kwargs)
-
-     if compute_hessp:
-         p = pt.tensor("p", shape=inputs[0].type.shape)
-         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-         f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)
-
-     return [f_loss_and_grad, f_hess, f_hessp]
-
-
- def scipy_optimize_funcs_from_loss(
-     loss: TensorVariable,
-     inputs: list[TensorVariable],
-     initial_point_dict: dict[str, np.ndarray | float | int],
-     use_grad: bool,
-     use_hess: bool,
-     use_hessp: bool,
-     gradient_backend: GradientBackend = "pytensor",
-     compile_kwargs: dict | None = None,
- ) -> tuple[Callable, ...]:
-     """
-     Compile loss functions for use with scipy.optimize.minimize.
-
-     Parameters
-     ----------
-     loss: TensorVariable
-         The loss function to compile.
-     inputs: list[TensorVariable]
-         The input variables to the loss function.
-     initial_point_dict: dict[str, np.ndarray | float | int]
-         Dictionary mapping variable names to initial values. Used to determine the shapes of the input variables.
-     use_grad: bool
-         Whether to compile a function that computes the gradients of the loss function.
-     use_hess: bool
-         Whether to compile a function that computes the Hessian of the loss function.
-     use_hessp: bool
-         Whether to compile a function that computes the Hessian-vector product of the loss function.
-     gradient_backend: str, default "pytensor"
-         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
-     compile_kwargs:
-         Additional keyword arguments to pass to the ``pm.compile`` function.
-
-     Returns
-     -------
-     f_loss: Callable
-         The compiled loss function.
-     f_hess: Callable | None
-         The compiled hessian function, or None if use_hess is False.
-     f_hessp: Callable | None
-         The compiled hessian-vector product function, or None if use_hessp is False.
-     """
-
-     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
-
-     if (use_hess or use_hessp) and not use_grad:
-         raise ValueError(
-             "Cannot compute hessian or hessian-vector product without also computing the gradient"
-         )
-
-     if gradient_backend not in VALID_BACKENDS:
-         raise ValueError(
-             f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}"
-         )
-
-     use_jax_gradients = (gradient_backend == "jax") and use_grad
-     if use_jax_gradients and not find_spec("jax"):
-         raise ImportError("JAX must be installed to use JAX gradients")
-
-     mode = compile_kwargs.get("mode", None)
-     if mode is None and use_jax_gradients:
-         compile_kwargs["mode"] = "JAX"
-     elif mode != "JAX" and use_jax_gradients:
-         raise ValueError(
-             'jax gradients can only be used when ``compile_kwargs["mode"]`` is set to "JAX"'
-         )
-
-     if not isinstance(inputs, list):
-         inputs = [inputs]
-
-     [loss], flat_input = join_nonshared_inputs(
-         point=initial_point_dict, outputs=[loss], inputs=inputs
-     )
-
-     # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
-     # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
-     # away.
-     if use_jax_gradients:
-         from pymc.sampling.jax import _replace_shared_variables
-
-         [loss] = _replace_shared_variables([loss])
-
-     compute_grad = use_grad and not use_jax_gradients
-     compute_hess = use_hess and not use_jax_gradients
-     compute_hessp = use_hessp and not use_jax_gradients
-
-     funcs = _compile_functions_for_scipy_optimize(
-         loss=loss,
-         inputs=[flat_input],
-         compute_grad=compute_grad,
-         compute_hess=compute_hess,
-         compute_hessp=compute_hessp,
-         compile_kwargs=compile_kwargs,
-     )
-
-     # f_loss here is f_loss_and_grad if compute_grad = True. The name is unchanged to simplify the return values
-     f_loss = funcs.pop(0)
-     f_hess = funcs.pop(0) if compute_grad else None
-     f_hessp = funcs.pop(0) if compute_grad else None
-
-     if use_jax_gradients:
-         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-         f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)
-
-     return f_loss, f_hess, f_hessp
-
-
- def find_MAP(
-     method: minimize_method | Literal["basinhopping"],
-     *,
-     model: pm.Model | None = None,
-     use_grad: bool | None = None,
-     use_hessp: bool | None = None,
-     use_hess: bool | None = None,
-     initvals: dict | None = None,
-     random_seed: int | np.random.Generator | None = None,
-     return_raw: bool = False,
-     jitter_rvs: list[TensorVariable] | None = None,
-     progressbar: bool = True,
-     include_transformed: bool = True,
-     gradient_backend: GradientBackend = "pytensor",
-     compile_kwargs: dict | None = None,
-     **optimizer_kwargs,
- ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
-     """
-     Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize.
-
-     Parameters
-     ----------
-     model : pm.Model
-         The PyMC model to be fit. If None, the current model context is used.
-     method : str
-         The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
-         trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
-
-         See scipy.optimize.minimize documentation for details.
-     use_grad : bool | None, optional
-         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
-         the ``method``.
-     use_hessp : bool | None, optional
-         Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on
-         the ``method``.
-     use_hess : bool | None, optional
-         Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on
-         the ``method``.
-     initvals : None | dict, optional
-         Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
-         If None, the model's default initial values are used.
-     random_seed : None | int | np.random.Generator, optional
-         Seed for the random number generator or a numpy Generator for reproducibility.
-     return_raw : bool, default False
-         Whether to also return the full output of ``scipy.optimize.minimize``.
-     jitter_rvs : list of TensorVariables, optional
-         Variables whose initial values should be jittered. If None, all variables are jittered.
-     progressbar : bool, optional
-         Whether to display a progress bar during optimization. Defaults to True.
-     include_transformed: bool, optional
-         Whether to include transformed variable values in the returned dictionary. Defaults to True.
-     gradient_backend: str, default "pytensor"
-         Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
-     compile_kwargs: dict, optional
-         Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
-     **optimizer_kwargs
-         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
-         ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
-         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
-
-     Returns
-     -------
-     optimizer_result: dict[str, np.ndarray] or tuple[dict[str, np.ndarray], OptimizeResult]
-         Dictionary with names of random variables as keys, and optimization results as values. If return_raw is True,
-         also returns the object returned by ``scipy.optimize.minimize``.
-     """
-     model = pm.modelcontext(model)
-     frozen_model = freeze_dims_and_data(model)
-
-     jitter_rvs = [] if jitter_rvs is None else jitter_rvs
-     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
-
-     ipfn = make_initial_point_fn(
-         model=frozen_model,
-         jitter_rvs=set(jitter_rvs),
-         return_transformed=True,
-         overrides=initvals,
-     )
-
-     start_dict = ipfn(random_seed)
-     vars_dict = {var.name: var for var in frozen_model.continuous_value_vars}
-     initial_params = DictToArrayBijection.map(
-         {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
-     )
-
-     do_basinhopping = method == "basinhopping"
-     minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
-
-     if do_basinhopping:
-         # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
-         # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
-         # if one isn't provided.
-
-         method = minimizer_kwargs.pop("method", "L-BFGS-B")
-         minimizer_kwargs["method"] = method
-
-     use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
-         method, use_grad, use_hess, use_hessp
-     )
-
-     f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
-         loss=-frozen_model.logp(jacobian=False),
-         inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
-         initial_point_dict=start_dict,
-         use_grad=use_grad,
-         use_hess=use_hess,
-         use_hessp=use_hessp,
-         gradient_backend=gradient_backend,
-         compile_kwargs=compile_kwargs,
-     )
-
-     args = optimizer_kwargs.pop("args", None)
-
-     # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
-     # if so. That is why the jac argument is not passed here in either branch.
-
-     if do_basinhopping:
-         if "args" not in minimizer_kwargs:
-             minimizer_kwargs["args"] = args
-         if "hess" not in minimizer_kwargs:
-             minimizer_kwargs["hess"] = f_hess
-         if "hessp" not in minimizer_kwargs:
-             minimizer_kwargs["hessp"] = f_hessp
-         if "method" not in minimizer_kwargs:
-             minimizer_kwargs["method"] = method
-
-         optimizer_result = basinhopping(
-             func=f_logp,
-             x0=cast(np.ndarray[float], initial_params.data),
-             progressbar=progressbar,
-             minimizer_kwargs=minimizer_kwargs,
-             **optimizer_kwargs,
-         )
-
-     else:
-         optimizer_result = minimize(
-             f=f_logp,
-             x0=cast(np.ndarray[float], initial_params.data),
-             args=args,
-             hess=f_hess,
-             hessp=f_hessp,
-             progressbar=progressbar,
-             method=method,
-             **optimizer_kwargs,
-         )
-
-     raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
-     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
-     unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")(
-         DictToArrayBijection.rmap(raveled_optimized)
-     )
-
-     optimized_point = {
-         var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
-     }
-
-     if return_raw:
-         return optimized_point, optimizer_result
-
-     return optimized_point
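
For orientation, a minimal usage sketch of the find_MAP API defined in the removed module above. The toy model and data are invented for illustration, and the relocated implementation in laplace_approx/find_map.py is assumed to keep a compatible interface (its return type may differ in 0.4.0), so treat this as illustrative rather than definitive:

    import numpy as np
    import pymc as pm

    # 0.2.7 location shown in this diff; in 0.4.0 the module lives at
    # pymc_extras.inference.laplace_approx.find_map (see the file list above).
    from pymc_extras.inference.find_map import find_MAP

    rng = np.random.default_rng(0)

    with pm.Model() as model:  # toy model, for illustration only
        mu = pm.Normal("mu", 0.0, 1.0)
        sigma = pm.HalfNormal("sigma", 1.0)
        pm.Normal("y", mu=mu, sigma=sigma, observed=rng.normal(loc=1.0, size=100))

        # method is the only required argument; use_grad/use_hess/use_hessp default
        # to sensible values for the chosen method via set_optimizer_function_defaults.
        # With return_raw=True the scipy OptimizeResult is returned as well.
        map_point, raw = find_MAP(method="L-BFGS-B", return_raw=True, progressbar=False)

    print({k: v for k, v in map_point.items() if k in ("mu", "sigma")})
    print(raw.success)
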