pymc-extras 0.2.1__tar.gz → 0.2.3__tar.gz

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (121)
  1. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/PKG-INFO +16 -4
  2. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/__init__.py +2 -0
  3. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/inference/find_map.py +36 -16
  4. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/inference/fit.py +0 -4
  5. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/inference/laplace.py +17 -10
  6. pymc_extras-0.2.3/pymc_extras/inference/pathfinder/__init__.py +3 -0
  7. pymc_extras-0.2.3/pymc_extras/inference/pathfinder/importance_sampling.py +139 -0
  8. pymc_extras-0.2.3/pymc_extras/inference/pathfinder/lbfgs.py +190 -0
  9. pymc_extras-0.2.3/pymc_extras/inference/pathfinder/pathfinder.py +1746 -0
  10. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/model/marginal/marginal_model.py +2 -1
  11. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/model/model_api.py +18 -2
  12. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/core/compile.py +1 -1
  13. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/core/statespace.py +79 -36
  14. pymc_extras-0.2.3/pymc_extras/version.txt +1 -0
  15. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras.egg-info/PKG-INFO +16 -4
  16. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras.egg-info/SOURCES.txt +4 -1
  17. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras.egg-info/requires.txt +2 -1
  18. pymc_extras-0.2.3/requirements.txt +3 -0
  19. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/setup.py +1 -1
  20. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/model/test_model_api.py +9 -0
  21. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/test_statespace.py +54 -0
  22. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/test_find_map.py +19 -14
  23. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/test_laplace.py +42 -15
  24. pymc_extras-0.2.3/tests/test_pathfinder.py +173 -0
  25. pymc_extras-0.2.1/pymc_extras/inference/pathfinder.py +0 -134
  26. pymc_extras-0.2.1/pymc_extras/version.txt +0 -1
  27. pymc_extras-0.2.1/requirements.txt +0 -2
  28. pymc_extras-0.2.1/tests/test_pathfinder.py +0 -45
  29. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/CODE_OF_CONDUCT.md +0 -0
  30. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/CONTRIBUTING.md +0 -0
  31. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/LICENSE +0 -0
  32. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/MANIFEST.in +0 -0
  33. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/README.md +0 -0
  34. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/distributions/__init__.py +0 -0
  35. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/distributions/continuous.py +0 -0
  36. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/distributions/discrete.py +0 -0
  37. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/distributions/histogram_utils.py +0 -0
  38. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/distributions/multivariate/__init__.py +0 -0
  39. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/distributions/multivariate/r2d2m2cp.py +0 -0
  40. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/distributions/timeseries.py +0 -0
  41. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/gp/__init__.py +0 -0
  42. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/gp/latent_approx.py +0 -0
  43. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/inference/__init__.py +0 -0
  44. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/inference/smc/__init__.py +0 -0
  45. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/inference/smc/sampling.py +0 -0
  46. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/linearmodel.py +0 -0
  47. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/model/__init__.py +0 -0
  48. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/model/marginal/__init__.py +0 -0
  49. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/model/marginal/distributions.py +0 -0
  50. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/model/marginal/graph_analysis.py +0 -0
  51. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/model/transforms/__init__.py +0 -0
  52. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/model/transforms/autoreparam.py +0 -0
  53. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/model_builder.py +0 -0
  54. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/preprocessing/__init__.py +0 -0
  55. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/preprocessing/standard_scaler.py +0 -0
  56. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/printing.py +0 -0
  57. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/__init__.py +0 -0
  58. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/core/__init__.py +0 -0
  59. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/core/representation.py +0 -0
  60. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/filters/__init__.py +0 -0
  61. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/filters/distributions.py +0 -0
  62. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/filters/kalman_filter.py +0 -0
  63. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/filters/kalman_smoother.py +0 -0
  64. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/filters/utilities.py +0 -0
  65. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/models/ETS.py +0 -0
  66. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/models/SARIMAX.py +0 -0
  67. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/models/VARMAX.py +0 -0
  68. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/models/__init__.py +0 -0
  69. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/models/structural.py +0 -0
  70. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/models/utilities.py +0 -0
  71. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/utils/__init__.py +0 -0
  72. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/utils/constants.py +0 -0
  73. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/utils/coord_tools.py +0 -0
  74. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/statespace/utils/data_tools.py +0 -0
  75. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/utils/__init__.py +0 -0
  76. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/utils/linear_cg.py +0 -0
  77. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/utils/model_equivalence.py +0 -0
  78. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/utils/pivoted_cholesky.py +0 -0
  79. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/utils/prior.py +0 -0
  80. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/utils/spline.py +0 -0
  81. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/version.py +0 -0
  82. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras.egg-info/dependency_links.txt +0 -0
  83. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras.egg-info/top_level.txt +0 -0
  84. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/pyproject.toml +0 -0
  85. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/requirements-dev.txt +0 -0
  86. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/requirements-docs.txt +0 -0
  87. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/setup.cfg +0 -0
  88. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/__init__.py +0 -0
  89. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/distributions/__init__.py +0 -0
  90. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/distributions/test_continuous.py +0 -0
  91. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/distributions/test_discrete.py +0 -0
  92. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/distributions/test_discrete_markov_chain.py +0 -0
  93. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/distributions/test_multivariate.py +0 -0
  94. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/model/__init__.py +0 -0
  95. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/model/marginal/__init__.py +0 -0
  96. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/model/marginal/test_distributions.py +0 -0
  97. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/model/marginal/test_graph_analysis.py +0 -0
  98. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/model/marginal/test_marginal_model.py +0 -0
  99. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/__init__.py +0 -0
  100. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/test_ETS.py +0 -0
  101. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/test_SARIMAX.py +0 -0
  102. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/test_VARMAX.py +0 -0
  103. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/test_coord_assignment.py +0 -0
  104. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/test_distributions.py +0 -0
  105. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/test_kalman_filter.py +0 -0
  106. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/test_representation.py +0 -0
  107. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/test_statespace_JAX.py +0 -0
  108. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/test_structural.py +0 -0
  109. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/utilities/__init__.py +0 -0
  110. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/utilities/shared_fixtures.py +0 -0
  111. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/utilities/statsmodel_local_level.py +0 -0
  112. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/statespace/utilities/test_helpers.py +0 -0
  113. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/test_blackjax_smc.py +0 -0
  114. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/test_histogram_approximation.py +0 -0
  115. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/test_linearmodel.py +0 -0
  116. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/test_model_builder.py +0 -0
  117. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/test_pivoted_cholesky.py +0 -0
  118. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/test_printing.py +0 -0
  119. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/test_prior_from_trace.py +0 -0
  120. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/test_splines.py +0 -0
  121. {pymc_extras-0.2.1 → pymc_extras-0.2.3}/tests/utils.py +0 -0

{pymc_extras-0.2.1 → pymc_extras-0.2.3}/PKG-INFO
@@ -1,11 +1,11 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: pymc-extras
- Version: 0.2.1
+ Version: 0.2.3
  Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
  Home-page: http://github.com/pymc-devs/pymc-extras
  Maintainer: PyMC Developers
  Maintainer-email: pymc.devs@gmail.com
- License: Apache License, Version 2.0
+ License: Apache-2.0
  Classifier: Development Status :: 5 - Production/Stable
  Classifier: Programming Language :: Python
  Classifier: Programming Language :: Python :: 3
@@ -20,8 +20,9 @@ Classifier: Operating System :: OS Independent
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: pymc>=5.19.1
+ Requires-Dist: pymc>=5.20
  Requires-Dist: scikit-learn
+ Requires-Dist: better-optimize
  Provides-Extra: dask-histogram
  Requires-Dist: dask[complete]; extra == "dask-histogram"
  Requires-Dist: xhistogram; extra == "dask-histogram"
@@ -34,6 +35,17 @@ Provides-Extra: dev
  Requires-Dist: dask[all]; extra == "dev"
  Requires-Dist: blackjax; extra == "dev"
  Requires-Dist: statsmodels; extra == "dev"
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license
+ Dynamic: maintainer
+ Dynamic: maintainer-email
+ Dynamic: provides-extra
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary

  # Welcome to `pymc-extras`
  <a href="https://gitpod.io/#https://github.com/pymc-devs/pymc-extras">

{pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/__init__.py
@@ -15,7 +15,9 @@ import logging

  from pymc_extras import gp, statespace, utils
  from pymc_extras.distributions import *
+ from pymc_extras.inference.find_map import find_MAP
  from pymc_extras.inference.fit import fit
+ from pymc_extras.inference.laplace import fit_laplace
  from pymc_extras.model.marginal.marginal_model import (
      MarginalModel,
      marginalize,
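
These two new imports, together with the existing fit entry point, expose the reworked inference API at the package top level. A hedged usage sketch, not taken from the diff; the toy model and keyword arguments are assumptions to be checked against the 0.2.3 docstrings:

import pymc as pm
import pymc_extras as pmx

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu=mu, sigma=1, observed=[0.1, -0.3, 0.2])

    # fit() dispatches to fit_pathfinder(**kwargs) when method="pathfinder"
    idata_pathfinder = pmx.fit(method="pathfinder")

    # find_MAP is now importable straight from pymc_extras; "L-BFGS-B" is assumed
    # here to be an accepted scipy.optimize.minimize method name
    map_point = pmx.find_MAP(method="L-BFGS-B")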

{pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/inference/find_map.py
@@ -1,9 +1,9 @@
  import logging

  from collections.abc import Callable
+ from importlib.util import find_spec
  from typing import Literal, cast, get_args

- import jax
  import numpy as np
  import pymc as pm
  import pytensor
@@ -30,13 +30,29 @@ VALID_BACKENDS = get_args(GradientBackend)
  def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
      method_info = MINIMIZE_MODE_KWARGS[method].copy()

-     use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-     use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-     use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
-
      if use_hess and use_hessp:
+         _log.warning(
+             'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+             'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
+             'Setting "use_hess" to False.'
+         )
          use_hess = False

+     use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+     if use_hessp is not None and use_hess is None:
+         use_hess = not use_hessp
+
+     elif use_hess is not None and use_hessp is None:
+         use_hessp = not use_hess
+
+     elif use_hessp is None and use_hess is None:
+         use_hessp = method_info["uses_hessp"]
+         use_hess = method_info["uses_hess"]
+         if use_hessp and use_hess:
+             # If a method could use either hess or hessp, we default to using hessp
+             use_hess = False
+
      return use_grad, use_hess, use_hessp

@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
          The nearest positive semi-definite matrix to the input matrix.
      """
      C = (A + A.T) / 2
-     eigval, eigvec = np.linalg.eig(C)
+     eigval, eigvec = np.linalg.eigh(C)
      eigval[eigval < 0] = 0

      return eigvec @ np.diag(eigval) @ eigvec.T
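
The hunk above switches the eigendecomposition to np.linalg.eigh, the appropriate routine for the symmetrized matrix C, avoiding the spurious complex output np.linalg.eig can produce from round-off. A self-contained sketch of the same projection (the test matrix is made up):

import numpy as np

def get_nearest_psd(A: np.ndarray) -> np.ndarray:
    # symmetrize, decompose with eigh, clip negative eigenvalues, rebuild
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(C)
    eigval[eigval < 0] = 0
    return eigvec @ np.diag(eigval) @ eigvec.T

A = np.array([[1.0, 2.0], [0.0, -1.0]])  # neither symmetric nor PSD
A_psd = get_nearest_psd(A)
assert np.all(np.linalg.eigvalsh(A_psd) >= -1e-12)  # non-negative up to round-off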
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
      return f_untransform(posterior_draws)


- def _compile_jax_gradients(
+ def _compile_grad_and_hess_to_jax(
      f_loss: Function, use_hess: bool, use_hessp: bool
  ) -> tuple[Callable | None, Callable | None]:
      """
@@ -122,6 +138,8 @@ def _compile_jax_gradients(
      f_hessp: Callable | None
          The compiled hessian-vector product function, or None if use_hessp is False.
      """
+     import jax
+
      f_hess = None
      f_hessp = None

@@ -152,7 +170,7 @@
      return f_loss_and_grad, f_hess, f_hessp


- def _compile_functions(
+ def _compile_functions_for_scipy_optimize(
      loss: TensorVariable,
      inputs: list[TensorVariable],
      compute_grad: bool,
@@ -177,7 +195,7 @@
      compute_hessp: bool
          Whether to compile a function that computes the Hessian-vector product of the loss function.
      compile_kwargs: dict, optional
-         Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+         Additional keyword arguments to pass to the ``pm.compile`` function.

      Returns
      -------
@@ -193,19 +211,19 @@
      if compute_grad:
          grads = pytensor.gradient.grad(loss, inputs)
          grad = pt.concatenate([grad.ravel() for grad in grads])
-         f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+         f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
      else:
-         f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+         f_loss = pm.compile(inputs, loss, **compile_kwargs)
          return [f_loss]

      if compute_hess:
          hess = pytensor.gradient.jacobian(grad, inputs)[0]
-         f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+         f_hess = pm.compile(inputs, hess, **compile_kwargs)

      if compute_hessp:
          p = pt.tensor("p", shape=inputs[0].type.shape)
          hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-         f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+         f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)

      return [f_loss_and_grad, f_hess, f_hessp]

@@ -240,7 +258,7 @@ def scipy_optimize_funcs_from_loss(
      gradient_backend: str, default "pytensor"
          Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
      compile_kwargs:
-         Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+         Additional keyword arguments to pass to the ``pm.compile`` function.

      Returns
      -------
@@ -265,6 +283,8 @@
      )

      use_jax_gradients = (gradient_backend == "jax") and use_grad
+     if use_jax_gradients and not find_spec("jax"):
+         raise ImportError("JAX must be installed to use JAX gradients")

      mode = compile_kwargs.get("mode", None)
      if mode is None and use_jax_gradients:
@@ -285,7 +305,7 @@
      compute_hess = use_hess and not use_jax_gradients
      compute_hessp = use_hessp and not use_jax_gradients

-     funcs = _compile_functions(
+     funcs = _compile_functions_for_scipy_optimize(
          loss=loss,
          inputs=[flat_input],
          compute_grad=compute_grad,
@@ -301,7 +321,7 @@

      if use_jax_gradients:
          # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-         f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+         f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)

      return f_loss, f_hess, f_hessp
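
Several hunks in this file swap pm.compile_pymc for pm.compile, the name used by the pymc version this release pins (>=5.20). A small sketch of the compile-loss-and-gradient pattern those hunks rely on, using a toy loss rather than anything from the package:

import pymc as pm
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
loss = (x**2).sum()
grad = pytensor.gradient.grad(loss, [x])[0]

# pm.compile is the renamed pm.compile_pymc; inputs and outputs follow the same pattern
# as the hunks above
f_loss_and_grad = pm.compile([x], [loss, grad])
value, gradient = f_loss_and_grad([1.0, 2.0])  # 5.0 and [2.0, 4.0]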

{pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/inference/fit.py
@@ -11,7 +11,6 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- from importlib.util import find_spec


  def fit(method, **kwargs):
@@ -31,9 +30,6 @@ def fit(method, **kwargs):
      arviz.InferenceData
      """
      if method == "pathfinder":
-         if find_spec("blackjax") is None:
-             raise RuntimeError("Need BlackJAX to use `pathfinder`")
-
          from pymc_extras.inference.pathfinder import fit_pathfinder

          return fit_pathfinder(**kwargs)

{pymc_extras-0.2.1 → pymc_extras-0.2.3}/pymc_extras/inference/laplace.py
@@ -16,6 +16,7 @@
  import logging

  from functools import reduce
+ from importlib.util import find_spec
  from itertools import product
  from typing import Literal

@@ -231,7 +232,7 @@ def add_data_to_inferencedata(
      return idata


- def fit_mvn_to_MAP(
+ def fit_mvn_at_MAP(
      optimized_point: dict[str, np.ndarray],
      model: pm.Model | None = None,
      on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -276,6 +277,9 @@ def fit_mvn_to_MAP(
      inverse_hessian: np.ndarray
          The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
      """
+     if gradient_backend == "jax" and not find_spec("jax"):
+         raise ImportError("JAX must be installed to use JAX gradients")
+
      model = pm.modelcontext(model)
      compile_kwargs = {} if compile_kwargs is None else compile_kwargs
      frozen_model = freeze_dims_and_data(model)
@@ -344,8 +348,10 @@ def sample_laplace_posterior(

      Parameters
      ----------
-     mu
-     H_inv
+     mu: RaveledVars
+         The MAP estimate of the model parameters.
+     H_inv: np.ndarray
+         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
      model : Model
          A PyMC model
      chains : int
@@ -384,9 +390,7 @@
              constrained_rvs, replace={unconstrained_vector: batched_values}
          )

-         f_constrain = pm.compile_pymc(
-             inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
-         )
+         f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
          posterior_draws = f_constrain(posterior_draws)

      else:
@@ -472,15 +476,17 @@ def fit_laplace(
          and 1).

          .. warning::
-             This argumnet should be considered highly experimental. It has not been verified if this method produces
+             This argument should be considered highly experimental. It has not been verified if this method produces
              valid draws from the posterior. **Use at your own risk**.

      gradient_backend: str, default "pytensor"
          The backend to use for gradient computations. Must be one of "pytensor" or "jax".
      chains: int, default: 2
-         The number of sampling chains running in parallel.
+         The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+         because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+         compatible with the ArviZ library.
      draws: int, default: 500
-         The number of samples to draw from the approximated posterior.
+         The number of samples to draw from the approximated posterior. Totals samples will be chains * draws.
      on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
          What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
          If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
@@ -547,11 +553,12 @@
          **optimizer_kwargs,
      )

-     mu, H_inv = fit_mvn_to_MAP(
+     mu, H_inv = fit_mvn_at_MAP(
          optimized_point=optimized_point,
          model=model,
          on_bad_cov=on_bad_cov,
          transform_samples=fit_in_unconstrained_space,
+         gradient_backend=gradient_backend,
          zero_tol=zero_tol,
          diag_jitter=diag_jitter,
          compile_kwargs=compile_kwargs,
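
The docstring hunks above clarify that chains is only an output dimension kept for ArviZ compatibility and that chains * draws samples are returned in total. A hedged usage sketch of the renamed pipeline; the toy model and the call inside the model context are assumptions, not taken from the diff:

import numpy as np
import pymc as pm
import pymc_extras as pmx

y = np.random.default_rng(0).normal(loc=1.0, scale=2.0, size=200)

with pm.Model():
    mu = pm.Normal("mu", 0, 10)
    sigma = pm.HalfNormal("sigma", 5)
    pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

    # 2 chains x 500 draws = 1000 total samples from the Gaussian approximation;
    # the chain dimension exists only to keep the output ArviZ-compatible
    idata = pmx.fit_laplace(chains=2, draws=500)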

pymc_extras-0.2.3/pymc_extras/inference/pathfinder/__init__.py
@@ -0,0 +1,3 @@
+ from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
+
+ __all__ = ["fit_pathfinder"]

pymc_extras-0.2.3/pymc_extras/inference/pathfinder/importance_sampling.py
@@ -0,0 +1,139 @@
+ import logging
+ import warnings as _warnings
+
+ from dataclasses import dataclass, field
+ from typing import Literal
+
+ import arviz as az
+ import numpy as np
+
+ from numpy.typing import NDArray
+ from scipy.special import logsumexp
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass(frozen=True)
+ class ImportanceSamplingResult:
+     """container for importance sampling results"""
+
+     samples: NDArray
+     pareto_k: float | None = None
+     warnings: list[str] = field(default_factory=list)
+     method: str = "none"
+
+
+ def importance_sampling(
+     samples: NDArray,
+     logP: NDArray,
+     logQ: NDArray,
+     num_draws: int,
+     method: Literal["psis", "psir", "identity", "none"] | None,
+     random_seed: int | None = None,
+ ) -> ImportanceSamplingResult:
+     """Pareto Smoothed Importance Resampling (PSIR)
+     This implements the Pareto Smooth Importance Resampling (PSIR) method, as described in Algorithm 5 of Zhang et al. (2022). The PSIR follows a similar approach to Algorithm 1 PSIS diagnostic from Yao et al., (2018). However, before computing the the importance ratio r_s, the logP and logQ are adjusted to account for the number multiple estimators (or paths). The process involves resampling from the original sample with replacement, with probabilities proportional to the computed importance weights from PSIS.
+
+     Parameters
+     ----------
+     samples : NDArray
+         samples from proposal distribution, shape (L, M, N)
+     logP : NDArray
+         log probability values of target distribution, shape (L, M)
+     logQ : NDArray
+         log probability values of proposal distribution, shape (L, M)
+     num_draws : int
+         number of draws to return where num_draws <= samples.shape[0]
+     method : str, optional
+         importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
+     random_seed : int | None
+
+     Returns
+     -------
+     ImportanceSamplingResult
+         importance sampled draws and other info based on the specified method
+
+     Future work!
+     ----------
+     - Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
+     - Implement Algorithm 2 VSBC marginal diagnostics from Yao et al. (2018)
+     - Incorporate these various diagnostics, sampling approaches and weighting functions into VI algorithms.
+
+     References
+     ----------
+     Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple Importance Sampling. Statistical Science, 34(1), 129-155. https://doi.org/10.1214/18-STS668
+
+     Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating Variational Inference. arXiv:1802.02538 [Stat]. http://arxiv.org/abs/1802.02538
+
+     Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
+     """
+
+     warnings = []
+     num_paths, _, N = samples.shape
+
+     if method == "none":
+         warnings.append(
+             "Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
+         )
+         return ImportanceSamplingResult(samples=samples, warnings=warnings)
+     else:
+         samples = samples.reshape(-1, N)
+         logP = logP.ravel()
+         logQ = logQ.ravel()
+
+     # adjust log densities
+     log_I = np.log(num_paths)
+     logP -= log_I
+     logQ -= log_I
+     logiw = logP - logQ
+
+     with _warnings.catch_warnings():
+         _warnings.filterwarnings(
+             "ignore", category=RuntimeWarning, message="overflow encountered in exp"
+         )
+         if method == "psis":
+             replace = False
+             logiw, pareto_k = az.psislw(logiw)
+         elif method == "psir":
+             replace = True
+             logiw, pareto_k = az.psislw(logiw)
+         elif method == "identity":
+             replace = False
+             pareto_k = None
+         else:
+             raise ValueError(f"Invalid importance sampling method: {method}")
+
+     # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
+     # Pareto k may not be a good diagnostic for Pathfinder.
+     # TODO: Find replacement diagnostics for Pathfinder.
+
+     p = np.exp(logiw - logsumexp(logiw))
+     rng = np.random.default_rng(random_seed)
+
+     try:
+         resampled = rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)
+         return ImportanceSamplingResult(
+             samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
+         )
+     except ValueError as e1:
+         if "Fewer non-zero entries in p than size" in str(e1):
+             num_nonzero = np.where(np.nonzero(p)[0], 1, 0).sum()
+             warnings.append(
+                 f"Not enough valid samples: {num_nonzero} available out of {num_draws} requested. Switching to psir importance sampling."
+             )
+             try:
+                 resampled = rng.choice(
+                     samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0
+                 )
+                 return ImportanceSamplingResult(
+                     samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
+                 )
+             except ValueError as e2:
+                 logger.error(
+                     "Importance sampling failed even with psir importance sampling. "
+                     "This might indicate invalid probability weights or insufficient valid samples."
+                 )
+                 raise ValueError(
+                     "Importance sampling failed for both with and without replacement"
+                 ) from e2
+         raise
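
Although this function is called internally by the new pathfinder implementation, its interface is plain NumPy, so it can be exercised directly. A hedged sketch with synthetic inputs shaped as the docstring describes (L paths, M draws per path, N parameters); the toy densities are made up:

import numpy as np
from pymc_extras.inference.pathfinder.importance_sampling import importance_sampling

rng = np.random.default_rng(0)
L, M, N = 4, 250, 3                        # paths, draws per path, parameters
samples = rng.normal(size=(L, M, N))       # draws from the proposal q
logP = -0.5 * (samples**2).sum(-1)         # toy target log-density (up to a constant)
logQ = logP + rng.normal(scale=0.1, size=(L, M))  # toy proposal log-density

result = importance_sampling(samples, logP=logP, logQ=logQ, num_draws=500, method="psis")
print(result.samples.shape, result.pareto_k)  # (500, 3) plus the PSIS shape diagnostic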

pymc_extras-0.2.3/pymc_extras/inference/pathfinder/lbfgs.py
@@ -0,0 +1,190 @@
+ import logging
+
+ from collections.abc import Callable
+ from dataclasses import dataclass, field
+ from enum import Enum, auto
+
+ import numpy as np
+
+ from numpy.typing import NDArray
+ from scipy.optimize import minimize
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass(slots=True)
+ class LBFGSHistory:
+     """History of LBFGS iterations."""
+
+     x: NDArray[np.float64]
+     g: NDArray[np.float64]
+     count: int
+
+     def __post_init__(self):
+         self.x = np.ascontiguousarray(self.x, dtype=np.float64)
+         self.g = np.ascontiguousarray(self.g, dtype=np.float64)
+
+
+ @dataclass(slots=True)
+ class LBFGSHistoryManager:
+     """manages and stores the history of lbfgs optimisation iterations.
+
+     Parameters
+     ----------
+     value_grad_fn : Callable
+         function that returns tuple of (value, gradient) given input x
+     x0 : NDArray
+         initial position
+     maxiter : int
+         maximum number of iterations to store
+     """
+
+     value_grad_fn: Callable[[NDArray[np.float64]], tuple[np.float64, NDArray[np.float64]]]
+     x0: NDArray[np.float64]
+     maxiter: int
+     x_history: NDArray[np.float64] = field(init=False)
+     g_history: NDArray[np.float64] = field(init=False)
+     count: int = field(init=False)
+
+     def __post_init__(self) -> None:
+         self.x_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)
+         self.g_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)
+         self.count = 0
+
+         value, grad = self.value_grad_fn(self.x0)
+         if np.all(np.isfinite(grad)) and np.isfinite(value):
+             self.add_entry(self.x0, grad)
+
+     def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None:
+         """adds new position and gradient to history.
+
+         Parameters
+         ----------
+         x : NDArray
+             position vector
+         g : NDArray
+             gradient vector
+         """
+         self.x_history[self.count] = x
+         self.g_history[self.count] = g
+         self.count += 1
+
+     def get_history(self) -> LBFGSHistory:
+         """returns history of optimisation iterations."""
+         return LBFGSHistory(
+             x=self.x_history[: self.count], g=self.g_history[: self.count], count=self.count
+         )
+
+     def __call__(self, x: NDArray[np.float64]) -> None:
+         value, grad = self.value_grad_fn(x)
+         if np.all(np.isfinite(grad)) and np.isfinite(value) and self.count < self.maxiter + 1:
+             self.add_entry(x, grad)
+
+
+ class LBFGSStatus(Enum):
+     CONVERGED = auto()
+     MAX_ITER_REACHED = auto()
+     DIVERGED = auto()
+     # Statuses that lead to Exceptions:
+     INIT_FAILED = auto()
+     LBFGS_FAILED = auto()
+
+
+ class LBFGSException(Exception):
+     DEFAULT_MESSAGE = "LBFGS failed."
+
+     def __init__(self, message=None, status: LBFGSStatus = LBFGSStatus.LBFGS_FAILED):
+         super().__init__(message or self.DEFAULT_MESSAGE)
+         self.status = status
+
+
+ class LBFGSInitFailed(LBFGSException):
+     DEFAULT_MESSAGE = "LBFGS failed to initialise."
+
+     def __init__(self, message=None):
+         super().__init__(message or self.DEFAULT_MESSAGE, LBFGSStatus.INIT_FAILED)
+
+
+ class LBFGS:
+     """L-BFGS optimizer wrapper around scipy's implementation.
+
+     Parameters
+     ----------
+     value_grad_fn : Callable
+         function that returns tuple of (value, gradient) given input x
+     maxcor : int
+         maximum number of variable metric corrections
+     maxiter : int, optional
+         maximum number of iterations, defaults to 1000
+     ftol : float, optional
+         function tolerance for convergence, defaults to 1e-5
+     gtol : float, optional
+         gradient tolerance for convergence, defaults to 1e-8
+     maxls : int, optional
+         maximum number of line search steps, defaults to 1000
+     """
+
+     def __init__(
+         self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000
+     ) -> None:
+         self.value_grad_fn = value_grad_fn
+         self.maxcor = maxcor
+         self.maxiter = maxiter
+         self.ftol = ftol
+         self.gtol = gtol
+         self.maxls = maxls
+
+     def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
+         """minimizes objective function starting from initial position.
+
+         Parameters
+         ----------
+         x0 : array_like
+             initial position
+
+         Returns
+         -------
+         x : NDArray
+             history of positions
+         g : NDArray
+             history of gradients
+         count : int
+             number of iterations
+         status : LBFGSStatus
+             final status of optimisation
+         """
+
+         x0 = np.array(x0, dtype=np.float64)
+
+         history_manager = LBFGSHistoryManager(
+             value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter
+         )
+
+         result = minimize(
+             self.value_grad_fn,
+             x0,
+             method="L-BFGS-B",
+             jac=True,
+             callback=history_manager,
+             options={
+                 "maxcor": self.maxcor,
+                 "maxiter": self.maxiter,
+                 "ftol": self.ftol,
+                 "gtol": self.gtol,
+                 "maxls": self.maxls,
+             },
+         )
+         history = history_manager.get_history()
+
+         # warnings and suggestions for LBFGSStatus are displayed at the end
+         if result.status == 1:
+             lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
+         elif (result.status == 2) or (history.count <= 1):
+             if result.nit <= 1:
+                 lbfgs_status = LBFGSStatus.INIT_FAILED
+             elif result.fun == np.inf:
+                 lbfgs_status = LBFGSStatus.DIVERGED
+         else:
+             lbfgs_status = LBFGSStatus.CONVERGED
+
+         return history.x, history.g, history.count, lbfgs_status
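
The wrapper above records every accepted position/gradient pair via the scipy callback and classifies the outcome with LBFGSStatus; the returned history is what the pathfinder code later consumes. A hedged sketch on a toy quadratic objective (made up, not taken from the package tests):

import numpy as np
from pymc_extras.inference.pathfinder.lbfgs import LBFGS

def value_grad_fn(x):
    # simple convex objective: f(x) = 0.5 * ||x||^2, gradient = x
    return 0.5 * float(x @ x), x

optimizer = LBFGS(value_grad_fn, maxcor=5, maxiter=100)
x_history, g_history, count, status = optimizer.minimize(np.array([3.0, -2.0]))

print(status)           # expected: LBFGSStatus.CONVERGED
print(x_history.shape)  # (count, 2): positions visited, including x0
print(x_history[-1])    # final iterate, near the minimum at the origin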