pymc-extras 0.2.2__tar.gz → 0.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/PKG-INFO +4 -3
  2. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/__init__.py +2 -0
  3. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/find_map.py +36 -16
  4. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/laplace.py +17 -10
  5. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/pathfinder/importance_sampling.py +23 -17
  6. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/pathfinder/pathfinder.py +55 -23
  7. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/marginal/marginal_model.py +2 -1
  8. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/core/compile.py +1 -1
  9. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/core/statespace.py +5 -4
  10. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/filters/distributions.py +9 -45
  11. pymc_extras-0.2.4/pymc_extras/version.txt +1 -0
  12. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras.egg-info/PKG-INFO +4 -3
  13. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras.egg-info/requires.txt +2 -1
  14. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pyproject.toml +3 -0
  15. pymc_extras-0.2.4/requirements.txt +3 -0
  16. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/setup.py +1 -1
  17. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_find_map.py +19 -14
  18. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_laplace.py +42 -15
  19. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_pathfinder.py +40 -10
  20. pymc_extras-0.2.2/pymc_extras/version.txt +0 -1
  21. pymc_extras-0.2.2/requirements.txt +0 -2
  22. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/CODE_OF_CONDUCT.md +0 -0
  23. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/CONTRIBUTING.md +0 -0
  24. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/LICENSE +0 -0
  25. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/MANIFEST.in +0 -0
  26. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/README.md +0 -0
  27. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/__init__.py +0 -0
  28. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/continuous.py +0 -0
  29. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/discrete.py +0 -0
  30. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/histogram_utils.py +0 -0
  31. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/multivariate/__init__.py +0 -0
  32. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/multivariate/r2d2m2cp.py +0 -0
  33. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/distributions/timeseries.py +0 -0
  34. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/gp/__init__.py +0 -0
  35. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/gp/latent_approx.py +0 -0
  36. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/__init__.py +0 -0
  37. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/fit.py +0 -0
  38. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/pathfinder/__init__.py +0 -0
  39. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/pathfinder/lbfgs.py +0 -0
  40. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/smc/__init__.py +0 -0
  41. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/smc/sampling.py +0 -0
  42. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/linearmodel.py +0 -0
  43. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/__init__.py +0 -0
  44. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/marginal/__init__.py +0 -0
  45. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/marginal/distributions.py +0 -0
  46. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/marginal/graph_analysis.py +0 -0
  47. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/model_api.py +0 -0
  48. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/transforms/__init__.py +0 -0
  49. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/transforms/autoreparam.py +0 -0
  50. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model_builder.py +0 -0
  51. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/preprocessing/__init__.py +0 -0
  52. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/preprocessing/standard_scaler.py +0 -0
  53. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/printing.py +0 -0
  54. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/__init__.py +0 -0
  55. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/core/__init__.py +0 -0
  56. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/core/representation.py +0 -0
  57. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/filters/__init__.py +0 -0
  58. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/filters/kalman_filter.py +0 -0
  59. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/filters/kalman_smoother.py +0 -0
  60. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/filters/utilities.py +0 -0
  61. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/models/ETS.py +0 -0
  62. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/models/SARIMAX.py +0 -0
  63. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/models/VARMAX.py +0 -0
  64. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/models/__init__.py +0 -0
  65. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/models/structural.py +0 -0
  66. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/models/utilities.py +0 -0
  67. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/utils/__init__.py +0 -0
  68. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/utils/constants.py +0 -0
  69. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/utils/coord_tools.py +0 -0
  70. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/utils/data_tools.py +0 -0
  71. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/utils/__init__.py +0 -0
  72. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/utils/linear_cg.py +0 -0
  73. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/utils/model_equivalence.py +0 -0
  74. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/utils/pivoted_cholesky.py +0 -0
  75. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/utils/prior.py +0 -0
  76. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/utils/spline.py +0 -0
  77. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/version.py +0 -0
  78. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras.egg-info/SOURCES.txt +0 -0
  79. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras.egg-info/dependency_links.txt +0 -0
  80. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras.egg-info/top_level.txt +0 -0
  81. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/requirements-dev.txt +0 -0
  82. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/requirements-docs.txt +0 -0
  83. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/setup.cfg +0 -0
  84. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/__init__.py +0 -0
  85. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/distributions/__init__.py +0 -0
  86. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/distributions/test_continuous.py +0 -0
  87. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/distributions/test_discrete.py +0 -0
  88. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/distributions/test_discrete_markov_chain.py +0 -0
  89. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/distributions/test_multivariate.py +0 -0
  90. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/model/__init__.py +0 -0
  91. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/model/marginal/__init__.py +0 -0
  92. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/model/marginal/test_distributions.py +0 -0
  93. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/model/marginal/test_graph_analysis.py +0 -0
  94. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/model/marginal/test_marginal_model.py +0 -0
  95. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/model/test_model_api.py +0 -0
  96. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/__init__.py +0 -0
  97. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_ETS.py +0 -0
  98. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_SARIMAX.py +0 -0
  99. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_VARMAX.py +0 -0
  100. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_coord_assignment.py +0 -0
  101. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_distributions.py +0 -0
  102. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_kalman_filter.py +0 -0
  103. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_representation.py +0 -0
  104. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_statespace.py +0 -0
  105. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_statespace_JAX.py +0 -0
  106. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/test_structural.py +0 -0
  107. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/utilities/__init__.py +0 -0
  108. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/utilities/shared_fixtures.py +0 -0
  109. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/utilities/statsmodel_local_level.py +0 -0
  110. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/statespace/utilities/test_helpers.py +0 -0
  111. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_blackjax_smc.py +0 -0
  112. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_histogram_approximation.py +0 -0
  113. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_linearmodel.py +0 -0
  114. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_model_builder.py +0 -0
  115. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_pivoted_cholesky.py +0 -0
  116. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_printing.py +0 -0
  117. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_prior_from_trace.py +0 -0
  118. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/test_splines.py +0 -0
  119. {pymc_extras-0.2.2 → pymc_extras-0.2.4}/tests/utils.py +0 -0
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/PKG-INFO
@@ -1,11 +1,11 @@
 Metadata-Version: 2.2
 Name: pymc-extras
-Version: 0.2.2
+Version: 0.2.4
 Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
 Home-page: http://github.com/pymc-devs/pymc-extras
 Maintainer: PyMC Developers
 Maintainer-email: pymc.devs@gmail.com
-License: Apache License, Version 2.0
+License: Apache-2.0
 Classifier: Development Status :: 5 - Production/Stable
 Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3
@@ -20,8 +20,9 @@ Classifier: Operating System :: OS Independent
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: pymc>=5.19.1
+Requires-Dist: pymc>=5.21.1
 Requires-Dist: scikit-learn
+Requires-Dist: better-optimize
 Provides-Extra: dask-histogram
 Requires-Dist: dask[complete]; extra == "dask-histogram"
 Requires-Dist: xhistogram; extra == "dask-histogram"
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/__init__.py
@@ -15,7 +15,9 @@ import logging
 
 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
+from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.laplace import fit_laplace
 from pymc_extras.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
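
These re-exports put find_MAP and fit_laplace on the package root alongside fit. A minimal usage sketch (the toy model and default arguments are illustrative, not part of the diff):

import pymc as pm
import pymc_extras as pmx

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu=mu, sigma=1, observed=[0.1, -0.4, 0.7])

    map_point = pmx.find_MAP()  # re-exported from pymc_extras.inference.find_map
    idata = pmx.fit_laplace()   # re-exported from pymc_extras.inference.laplace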
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/find_map.py
@@ -1,9 +1,9 @@
 import logging
 
 from collections.abc import Callable
+from importlib.util import find_spec
 from typing import Literal, cast, get_args
 
-import jax
 import numpy as np
 import pymc as pm
 import pytensor
@@ -30,13 +30,29 @@ VALID_BACKENDS = get_args(GradientBackend)
 def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()
 
-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
-
     if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
         use_hess = False
 
+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
     return use_grad, use_hess, use_hessp
 
 
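The reordered defaults mean an explicit choice for one Hessian mode now implies the complement for the other, and only when both are left unset do the method defaults apply, preferring hessp. A sketch of the intended behavior, assuming "trust-ncg" is registered in MINIMIZE_MODE_KWARGS as using the gradient and both Hessian modes:

from pymc_extras.inference.find_map import set_optimizer_function_defaults

# Explicitly requesting the Hessian-vector product now disables the full Hessian.
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
    "trust-ncg", use_grad=None, use_hess=None, use_hessp=True
)
assert (use_grad, use_hess, use_hessp) == (True, False, True)

# Leaving both unset defers to the method's capabilities, preferring hessp.
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
    "trust-ncg", use_grad=None, use_hess=None, use_hessp=None
)
assert use_hessp and not use_hess
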
@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
         The nearest positive semi-definite matrix to the input matrix.
     """
     C = (A + A.T) / 2
-    eigval, eigvec = np.linalg.eig(C)
+    eigval, eigvec = np.linalg.eigh(C)
     eigval[eigval < 0] = 0
 
     return eigvec @ np.diag(eigval) @ eigvec.T
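
The switch to eigh is more than cosmetic: np.linalg.eig on a symmetric matrix can return complex eigenpairs through floating-point round-off, while eigh guarantees real eigenvalues and orthonormal eigenvectors, so the clipped reconstruction stays real and symmetric. A self-contained check:

import numpy as np

A = np.array([[1.0, 2.0], [2.0, -3.0]])
C = (A + A.T) / 2
eigval, eigvec = np.linalg.eigh(C)  # real eigenvalues, orthonormal eigenvectors
eigval[eigval < 0] = 0              # clip the negative part of the spectrum
A_psd = eigvec @ np.diag(eigval) @ eigvec.T
assert np.all(np.linalg.eigvalsh(A_psd) >= -1e-12)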
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)
 
 
-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
     f_loss: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
@@ -122,6 +138,8 @@ def _compile_jax_gradients(
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
+    import jax
+
     f_hess = None
     f_hessp = None
 
@@ -152,7 +170,7 @@ def _compile_jax_gradients(
     return f_loss_and_grad, f_hess, f_hessp
 
 
-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
     loss: TensorVariable,
     inputs: list[TensorVariable],
     compute_grad: bool,
@@ -177,7 +195,7 @@ def _compile_functions(
     compute_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
     compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.
 
     Returns
     -------
@@ -193,19 +211,19 @@ def _compile_functions(
     if compute_grad:
         grads = pytensor.gradient.grad(loss, inputs)
         grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
     else:
-        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]
 
     if compute_hess:
         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)
 
     if compute_hessp:
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)
 
     return [f_loss_and_grad, f_hess, f_hessp]
 
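The compiled callables line up with scipy.optimize.minimize's interface: the fused value-and-gradient function is passed with jac=True, and the Hessian-vector product goes in via hessp. A hedged sketch, assuming f_loss_and_grad, f_hessp, and x0 come from scipy_optimize_funcs_from_loss:

from scipy.optimize import minimize

# fun returns (loss, grad); jac=True tells scipy to unpack the pair.
res = minimize(f_loss_and_grad, x0, jac=True, hessp=f_hessp, method="trust-ncg")
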
@@ -240,7 +258,7 @@ def scipy_optimize_funcs_from_loss(
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
     compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.
 
     Returns
     -------
@@ -265,6 +283,8 @@ def scipy_optimize_funcs_from_loss(
     )
 
     use_jax_gradients = (gradient_backend == "jax") and use_grad
+    if use_jax_gradients and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")
 
     mode = compile_kwargs.get("mode", None)
     if mode is None and use_jax_gradients:
@@ -285,7 +305,7 @@ def scipy_optimize_funcs_from_loss(
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients
 
-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
         loss=loss,
         inputs=[flat_input],
         compute_grad=compute_grad,
@@ -301,7 +321,7 @@ def scipy_optimize_funcs_from_loss(
 
     if use_jax_gradients:
         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)
 
     return f_loss, f_hess, f_hessp
 
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/laplace.py
@@ -16,6 +16,7 @@
 import logging
 
 from functools import reduce
+from importlib.util import find_spec
 from itertools import product
 from typing import Literal
 
@@ -231,7 +232,7 @@ def add_data_to_inferencedata(
     return idata
 
 
-def fit_mvn_to_MAP(
+def fit_mvn_at_MAP(
     optimized_point: dict[str, np.ndarray],
     model: pm.Model | None = None,
     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -276,6 +277,9 @@ def fit_mvn_to_MAP(
     inverse_hessian: np.ndarray
         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     """
+    if gradient_backend == "jax" and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")
+
     model = pm.modelcontext(model)
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
     frozen_model = freeze_dims_and_data(model)
@@ -344,8 +348,10 @@ def sample_laplace_posterior(
 
     Parameters
     ----------
-    mu
-    H_inv
+    mu: RaveledVars
+        The MAP estimate of the model parameters.
+    H_inv: np.ndarray
+        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     model : Model
         A PyMC model
     chains : int
@@ -384,9 +390,7 @@ def sample_laplace_posterior(
             constrained_rvs, replace={unconstrained_vector: batched_values}
         )
 
-        f_constrain = pm.compile_pymc(
-            inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
-        )
+        f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
         posterior_draws = f_constrain(posterior_draws)
 
     else:
@@ -472,15 +476,17 @@ def fit_laplace(
         and 1).
 
         .. warning::
-            This argumnet should be considered highly experimental. It has not been verified if this method produces
+            This argument should be considered highly experimental. It has not been verified if this method produces
             valid draws from the posterior. **Use at your own risk**.
 
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
-        The number of sampling chains running in parallel.
+        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior.
+        The number of samples to draw from the approximated posterior. Totals samples will be chains * draws.
     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
@@ -547,11 +553,12 @@ def fit_laplace(
         **optimizer_kwargs,
     )
 
-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         on_bad_cov=on_bad_cov,
         transform_samples=fit_in_unconstrained_space,
+        gradient_backend=gradient_backend,
         zero_tol=zero_tol,
         diag_jitter=diag_jitter,
         compile_kwargs=compile_kwargs,
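
Per the clarified docstring, chains only shapes the output for ArviZ compatibility; total draws are chains * draws. A sketch (model assumed in scope):

from pymc_extras import fit_laplace

idata = fit_laplace(model=model, chains=4, draws=250)
# 4 * 250 = 1000 total draws, laid out as (chain=4, draw=250, ...)
assert idata.posterior.sizes["chain"] == 4 and idata.posterior.sizes["draw"] == 250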
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/pathfinder/importance_sampling.py
@@ -20,7 +20,7 @@ class ImportanceSamplingResult:
     samples: NDArray
     pareto_k: float | None = None
     warnings: list[str] = field(default_factory=list)
-    method: str = "none"
+    method: str = "psis"
 
 
 def importance_sampling(
@@ -28,7 +28,7 @@ def importance_sampling(
     logP: NDArray,
     logQ: NDArray,
     num_draws: int,
-    method: Literal["psis", "psir", "identity", "none"] | None,
+    method: Literal["psis", "psir", "identity"] | None,
     random_seed: int | None = None,
 ) -> ImportanceSamplingResult:
     """Pareto Smoothed Importance Resampling (PSIR)
@@ -44,8 +44,15 @@ def importance_sampling(
         log probability values of proposal distribution, shape (L, M)
     num_draws : int
         number of draws to return where num_draws <= samples.shape[0]
-    method : str, optional
-        importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
+    method : str, None, optional
+        Method to apply sampling based on log importance weights (logP - logQ).
+        Options are:
+        "psis" : Pareto Smoothed Importance Sampling (default)
+            Recommended for more stable results.
+        "psir" : Pareto Smoothed Importance Resampling
+            Less stable than PSIS.
+        "identity" : Applies log importance weights directly without resampling.
+        None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
     random_seed : int | None
 
     Returns
@@ -71,11 +78,11 @@ def importance_sampling(
     warnings = []
     num_paths, _, N = samples.shape
 
-    if method == "none":
+    if method is None:
         warnings.append(
             "Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
         )
-        return ImportanceSamplingResult(samples=samples, warnings=warnings)
+        return ImportanceSamplingResult(samples=samples, warnings=warnings, method=method)
     else:
         samples = samples.reshape(-1, N)
         logP = logP.ravel()
@@ -91,17 +98,16 @@ def importance_sampling(
         _warnings.filterwarnings(
             "ignore", category=RuntimeWarning, message="overflow encountered in exp"
         )
-        if method == "psis":
-            replace = False
-            logiw, pareto_k = az.psislw(logiw)
-        elif method == "psir":
-            replace = True
-            logiw, pareto_k = az.psislw(logiw)
-        elif method == "identity":
-            replace = False
-            pareto_k = None
-        else:
-            raise ValueError(f"Invalid importance sampling method: {method}")
+        match method:
+            case "psis":
+                replace = False
+                logiw, pareto_k = az.psislw(logiw)
+            case "psir":
+                replace = True
+                logiw, pareto_k = az.psislw(logiw)
+            case "identity":
+                replace = False
+                pareto_k = None
 
     # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
     # Pareto k may not be a good diagnostic for Pathfinder.
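
Both Pareto-smoothed branches delegate to arviz; a standalone illustration of what az.psislw returns (the weights here are synthetic stand-ins for logP - logQ):

import arviz as az
import numpy as np

rng = np.random.default_rng(0)
logiw = rng.normal(size=1000)                # stand-in for log importance weights
logiw_smoothed, pareto_k = az.psislw(logiw)  # smoothed log-weights + shape diagnostic
# pareto_k > 0.7 conventionally flags unreliable importance weights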
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/inference/pathfinder/pathfinder.py
@@ -60,6 +60,7 @@ from pytensor.graph import Apply, Op, vectorize_graph
 from pytensor.tensor import TensorConstant, TensorVariable
 from rich.console import Console, Group
 from rich.padding import Padding
+from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
 from rich.table import Table
 from rich.text import Text
 
@@ -155,7 +156,7 @@ def convert_flat_trace_to_idata(
     postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
     inference_backend: Literal["pymc", "blackjax"] = "pymc",
     model: Model | None = None,
-    importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
+    importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
 ) -> az.InferenceData:
     """convert flattened samples to arviz InferenceData format.
 
@@ -180,7 +181,7 @@ def convert_flat_trace_to_idata(
         arviz inference data object
     """
 
-    if importance_sampling == "none":
+    if importance_sampling is None:
         # samples.ndim == 3 in this case, otherwise ndim == 2
         num_paths, num_pdraws, N = samples.shape
         samples = samples.reshape(-1, N)
@@ -219,7 +220,7 @@ def convert_flat_trace_to_idata(
         fn.trust_input = True
         result = fn(*list(trace.values()))
 
-        if importance_sampling == "none":
+        if importance_sampling is None:
             result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]
 
     elif inference_backend == "blackjax":
@@ -1188,7 +1189,7 @@ class MultiPathfinderResult:
     elbo_argmax: NDArray | None = None
     lbfgs_status: Counter = field(default_factory=Counter)
     path_status: Counter = field(default_factory=Counter)
-    importance_sampling: str = "none"
+    importance_sampling: str | None = "psis"
     warnings: list[str] = field(default_factory=list)
     pareto_k: float | None = None
 
@@ -1257,7 +1258,7 @@ class MultiPathfinderResult:
     def with_importance_sampling(
         self,
         num_draws: int,
-        method: Literal["psis", "psir", "identity", "none"] | None,
+        method: Literal["psis", "psir", "identity"] | None,
         random_seed: int | None = None,
     ) -> Self:
         """perform importance sampling"""
@@ -1395,7 +1396,7 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
 
     path_status_message = {
         PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter is may be too close to the mean posterior and a poor exploration of the parameter space. Consider increasing jitter if this occurence is high relative to the number of paths.",
-        PathStatus.INVALID_LOGP: "Invalid logP values occur when a path's logP values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
+        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
         PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
     }
 
@@ -1423,7 +1424,7 @@ def multipath_pathfinder(
     num_elbo_draws: int,
     jitter: float,
     epsilon: float,
-    importance_sampling: Literal["psis", "psir", "identity", "none"] | None,
+    importance_sampling: Literal["psis", "psir", "identity"] | None,
     progressbar: bool,
     concurrent: Literal["thread", "process"] | None,
     random_seed: RandomSeed,
@@ -1459,8 +1460,14 @@ def multipath_pathfinder(
         Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
     epsilon: float
         value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
-    importance_sampling : str, optional
-        importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
+    importance_sampling : str, None, optional
+        Method to apply sampling based on log importance weights (logP - logQ).
+        "psis" : Pareto Smoothed Importance Sampling (default)
+            Recommended for more stable results.
+        "psir" : Pareto Smoothed Importance Resampling
+            Less stable than PSIS.
+        "identity" : Applies log importance weights directly without resampling.
+        None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
     progressbar : bool, optional
         Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time.
     random_seed : RandomSeed, optional
@@ -1482,12 +1489,6 @@ def multipath_pathfinder(
         The result containing samples and other information from the Multi-Path Pathfinder algorithm.
     """
 
-    valid_importance_sampling = ["psis", "psir", "identity", "none", None]
-    if importance_sampling is None:
-        importance_sampling = "none"
-    if importance_sampling.lower() not in valid_importance_sampling:
-        raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
-
     *path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1)
 
     pathfinder_config = PathfinderConfig(
@@ -1521,12 +1522,20 @@ def multipath_pathfinder(
     results = []
     compute_start = time.time()
     try:
-        with CustomProgress(
+        desc = f"Paths Complete: {{path_idx}}/{num_paths}"
+        progress = CustomProgress(
+            "[progress.description]{task.description}",
+            BarColumn(),
+            "[progress.percentage]{task.percentage:>3.0f}%",
+            TimeRemainingColumn(),
+            TextColumn("/"),
+            TimeElapsedColumn(),
             console=Console(theme=default_progress_theme),
             disable=not progressbar,
-        ) as progress:
-            task = progress.add_task("Fitting", total=num_paths)
-            for result in generator:
+        )
+        with progress:
+            task = progress.add_task(desc.format(path_idx=0), completed=0, total=num_paths)
+            for path_idx, result in enumerate(generator, start=1):
                 try:
                     if isinstance(result, Exception):
                         raise result
@@ -1552,7 +1561,14 @@ def multipath_pathfinder(
                             lbfgs_status=LBFGSStatus.LBFGS_FAILED,
                         )
                     )
-                progress.update(task, advance=1)
+                finally:
+                    # TODO: display LBFGS and Path Status in real time
+                    progress.update(
+                        task,
+                        description=desc.format(path_idx=path_idx),
+                        completed=path_idx,
+                        refresh=True,
+                    )
     except (KeyboardInterrupt, StopIteration) as e:
         # if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
         if isinstance(e, StopIteration):
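
CustomProgress follows rich's Progress API, so the column spec and per-path updates above read as in rich itself. A standalone sketch of the same bar layout, with a plain loop standing in for the path generator:

from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn

progress = Progress(
    "[progress.description]{task.description}",
    BarColumn(),
    "[progress.percentage]{task.percentage:>3.0f}%",
    TimeRemainingColumn(),
    TextColumn("/"),
    TimeElapsedColumn(),
)
with progress:
    task = progress.add_task("Paths Complete: 0/8", total=8)
    for path_idx in range(1, 9):
        progress.update(
            task,
            description=f"Paths Complete: {path_idx}/8",
            completed=path_idx,
            refresh=True,
        )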
@@ -1606,7 +1622,7 @@ def fit_pathfinder(
     num_elbo_draws: int = 10,  # K
     jitter: float = 2.0,
     epsilon: float = 1e-8,
-    importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
+    importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
     progressbar: bool = True,
     concurrent: Literal["thread", "process"] | None = None,
     random_seed: RandomSeed | None = None,
@@ -1646,8 +1662,15 @@ def fit_pathfinder(
         Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
     epsilon: float
         value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
-    importance_sampling : str, optional
-        importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
+    importance_sampling : str, None, optional
+        Method to apply sampling based on log importance weights (logP - logQ).
+        Options are:
+        "psis" : Pareto Smoothed Importance Sampling (default)
+            Recommended for more stable results.
+        "psir" : Pareto Smoothed Importance Resampling
+            Less stable than PSIS.
+        "identity" : Applies log importance weights directly without resampling.
+        None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
     progressbar : bool, optional
         Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
     random_seed : RandomSeed, optional
@@ -1674,6 +1697,15 @@ def fit_pathfinder(
     """
 
     model = modelcontext(model)
+
+    valid_importance_sampling = {"psis", "psir", "identity", None}
+
+    if importance_sampling is not None:
+        importance_sampling = importance_sampling.lower()
+
+    if importance_sampling not in valid_importance_sampling:
+        raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
+
     N = DictToArrayBijection.map(model.initial_point()).data.shape[0]
 
     if maxcor is None:
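
With validation now in fit_pathfinder, string options are lower-cased before the membership check and None is the only non-string value. A hedged sketch of the None path (model assumed in scope; the import path follows the file layout above):

from pymc_extras.inference.pathfinder import fit_pathfinder

idata = fit_pathfinder(model=model, num_paths=4, importance_sampling=None)
# None keeps the raw per-path layout (num_paths, num_draws_per_path, N);
# "psis"/"psir"/"identity" instead return a flattened (num_draws, N) sample.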
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/model/marginal/marginal_model.py
@@ -19,7 +19,8 @@ from pymc.model.fgraph import (
     model_free_rv,
     model_from_fgraph,
 )
-from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold, toposort_replace
+from pymc.pytensorf import collect_default_updates, constant_fold, toposort_replace
+from pymc.pytensorf import compile as compile_pymc
 from pymc.util import RandomState, _get_seeds_per_chain
 from pytensor import In, Out
 from pytensor.compile import SharedVariable
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/core/compile.py
@@ -30,7 +30,7 @@ def compile_statespace(
 
     inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
 
-    _f = pm.compile_pymc(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
+    _f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
 
     def f(*, draws=1, **params):
         if isinstance(steps, pt.Variable):
{pymc_extras-0.2.2 → pymc_extras-0.2.4}/pymc_extras/statespace/core/statespace.py
@@ -28,7 +28,6 @@ from pymc_extras.statespace.filters import (
 )
 from pymc_extras.statespace.filters.distributions import (
     LinearGaussianStateSpace,
-    MvNormalSVD,
     SequenceMvNormal,
 )
 from pymc_extras.statespace.filters.utilities import stabilize
@@ -707,7 +706,7 @@ class PyMCStateSpace:
         with pymc_model:
             for param_name in self.param_names:
                 param = getattr(pymc_model, param_name, None)
-                if param:
+                if param is not None:
                     found_params.append(param.name)
 
         missing_params = list(set(self.param_names) - set(found_params))
@@ -746,7 +745,7 @@ class PyMCStateSpace:
         with pymc_model:
             for data_name in data_names:
                 data = getattr(pymc_model, data_name, None)
-                if data:
+                if data is not None:
                     found_data.append(data.name)
 
         missing_data = list(set(data_names) - set(found_data))
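
The `is not None` form matters because the attributes fetched here are symbolic model variables, and truth-testing a symbolic PyTensor variable raises instead of answering, so `if param:` was never a reliable presence check. A quick demonstration:

import pytensor.tensor as pt

x = pt.dscalar("x")
try:
    bool(x)  # symbolic variables have no concrete truth value
except TypeError:
    pass  # hence the presence check must be `param is not None`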
@@ -2233,7 +2232,9 @@ class PyMCStateSpace:
         if shock_trajectory is None:
             shock_trajectory = pt.zeros((n_steps, self.k_posdef))
             if Q is not None:
-                init_shock = MvNormalSVD("initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM])
+                init_shock = pm.MvNormal(
+                    "initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method="svd"
+                )
             else:
                 init_shock = pm.Deterministic(
                     "initial_shock",