pymc-extras 0.2.2__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/PKG-INFO +4 -3
  2. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/__init__.py +2 -0
  3. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/inference/find_map.py +36 -16
  4. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/inference/laplace.py +17 -10
  5. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/model/marginal/marginal_model.py +2 -1
  6. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/core/compile.py +1 -1
  7. pymc_extras-0.2.3/pymc_extras/version.txt +1 -0
  8. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras.egg-info/PKG-INFO +4 -3
  9. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras.egg-info/requires.txt +2 -1
  10. pymc_extras-0.2.3/requirements.txt +3 -0
  11. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/setup.py +1 -1
  12. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/test_find_map.py +19 -14
  13. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/test_laplace.py +42 -15
  14. pymc_extras-0.2.2/pymc_extras/version.txt +0 -1
  15. pymc_extras-0.2.2/requirements.txt +0 -2
  16. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/CODE_OF_CONDUCT.md +0 -0
  17. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/CONTRIBUTING.md +0 -0
  18. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/LICENSE +0 -0
  19. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/MANIFEST.in +0 -0
  20. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/README.md +0 -0
  21. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/distributions/__init__.py +0 -0
  22. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/distributions/continuous.py +0 -0
  23. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/distributions/discrete.py +0 -0
  24. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/distributions/histogram_utils.py +0 -0
  25. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/distributions/multivariate/__init__.py +0 -0
  26. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/distributions/multivariate/r2d2m2cp.py +0 -0
  27. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/distributions/timeseries.py +0 -0
  28. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/gp/__init__.py +0 -0
  29. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/gp/latent_approx.py +0 -0
  30. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/inference/__init__.py +0 -0
  31. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/inference/fit.py +0 -0
  32. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/inference/pathfinder/__init__.py +0 -0
  33. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/inference/pathfinder/importance_sampling.py +0 -0
  34. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/inference/pathfinder/lbfgs.py +0 -0
  35. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/inference/pathfinder/pathfinder.py +0 -0
  36. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/inference/smc/__init__.py +0 -0
  37. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/inference/smc/sampling.py +0 -0
  38. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/linearmodel.py +0 -0
  39. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/model/__init__.py +0 -0
  40. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/model/marginal/__init__.py +0 -0
  41. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/model/marginal/distributions.py +0 -0
  42. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/model/marginal/graph_analysis.py +0 -0
  43. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/model/model_api.py +0 -0
  44. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/model/transforms/__init__.py +0 -0
  45. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/model/transforms/autoreparam.py +0 -0
  46. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/model_builder.py +0 -0
  47. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/preprocessing/__init__.py +0 -0
  48. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/preprocessing/standard_scaler.py +0 -0
  49. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/printing.py +0 -0
  50. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/__init__.py +0 -0
  51. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/core/__init__.py +0 -0
  52. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/core/representation.py +0 -0
  53. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/core/statespace.py +0 -0
  54. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/filters/__init__.py +0 -0
  55. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/filters/distributions.py +0 -0
  56. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/filters/kalman_filter.py +0 -0
  57. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/filters/kalman_smoother.py +0 -0
  58. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/filters/utilities.py +0 -0
  59. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/models/ETS.py +0 -0
  60. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/models/SARIMAX.py +0 -0
  61. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/models/VARMAX.py +0 -0
  62. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/models/__init__.py +0 -0
  63. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/models/structural.py +0 -0
  64. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/models/utilities.py +0 -0
  65. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/utils/__init__.py +0 -0
  66. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/utils/constants.py +0 -0
  67. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/utils/coord_tools.py +0 -0
  68. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/statespace/utils/data_tools.py +0 -0
  69. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/utils/__init__.py +0 -0
  70. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/utils/linear_cg.py +0 -0
  71. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/utils/model_equivalence.py +0 -0
  72. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/utils/pivoted_cholesky.py +0 -0
  73. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/utils/prior.py +0 -0
  74. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/utils/spline.py +0 -0
  75. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras/version.py +0 -0
  76. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras.egg-info/SOURCES.txt +0 -0
  77. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras.egg-info/dependency_links.txt +0 -0
  78. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pymc_extras.egg-info/top_level.txt +0 -0
  79. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/pyproject.toml +0 -0
  80. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/requirements-dev.txt +0 -0
  81. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/requirements-docs.txt +0 -0
  82. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/setup.cfg +0 -0
  83. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/__init__.py +0 -0
  84. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/distributions/__init__.py +0 -0
  85. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/distributions/test_continuous.py +0 -0
  86. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/distributions/test_discrete.py +0 -0
  87. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/distributions/test_discrete_markov_chain.py +0 -0
  88. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/distributions/test_multivariate.py +0 -0
  89. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/model/__init__.py +0 -0
  90. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/model/marginal/__init__.py +0 -0
  91. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/model/marginal/test_distributions.py +0 -0
  92. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/model/marginal/test_graph_analysis.py +0 -0
  93. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/model/marginal/test_marginal_model.py +0 -0
  94. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/model/test_model_api.py +0 -0
  95. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/__init__.py +0 -0
  96. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/test_ETS.py +0 -0
  97. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/test_SARIMAX.py +0 -0
  98. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/test_VARMAX.py +0 -0
  99. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/test_coord_assignment.py +0 -0
  100. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/test_distributions.py +0 -0
  101. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/test_kalman_filter.py +0 -0
  102. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/test_representation.py +0 -0
  103. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/test_statespace.py +0 -0
  104. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/test_statespace_JAX.py +0 -0
  105. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/test_structural.py +0 -0
  106. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/utilities/__init__.py +0 -0
  107. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/utilities/shared_fixtures.py +0 -0
  108. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/utilities/statsmodel_local_level.py +0 -0
  109. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/statespace/utilities/test_helpers.py +0 -0
  110. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/test_blackjax_smc.py +0 -0
  111. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/test_histogram_approximation.py +0 -0
  112. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/test_linearmodel.py +0 -0
  113. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/test_model_builder.py +0 -0
  114. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/test_pathfinder.py +0 -0
  115. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/test_pivoted_cholesky.py +0 -0
  116. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/test_printing.py +0 -0
  117. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/test_prior_from_trace.py +0 -0
  118. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/test_splines.py +0 -0
  119. {pymc_extras-0.2.2 → pymc_extras-0.2.3}/tests/utils.py +0 -0
PKG-INFO
@@ -1,11 +1,11 @@
 Metadata-Version: 2.2
 Name: pymc-extras
-Version: 0.2.2
+Version: 0.2.3
 Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
 Home-page: http://github.com/pymc-devs/pymc-extras
 Maintainer: PyMC Developers
 Maintainer-email: pymc.devs@gmail.com
-License: Apache License, Version 2.0
+License: Apache-2.0
 Classifier: Development Status :: 5 - Production/Stable
 Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3
@@ -20,8 +20,9 @@ Classifier: Operating System :: OS Independent
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: pymc>=5.19.1
+Requires-Dist: pymc>=5.20
 Requires-Dist: scikit-learn
+Requires-Dist: better-optimize
 Provides-Extra: dask-histogram
 Requires-Dist: dask[complete]; extra == "dask-histogram"
 Requires-Dist: xhistogram; extra == "dask-histogram"
pymc_extras/__init__.py
@@ -15,7 +15,9 @@ import logging

 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
+from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.laplace import fit_laplace
 from pymc_extras.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
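The two new imports expose find_MAP and fit_laplace directly from the top-level pymc_extras namespace. A minimal usage sketch (the model and keyword values below are illustrative, not taken from this diff):

import numpy as np
import pymc as pm
import pymc_extras as pmx

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("y", mu=mu, sigma=sigma, observed=np.random.normal(size=100))

    # MAP estimate via scipy.optimize.minimize, wrapped by pymc-extras
    map_point = pmx.find_MAP(method="L-BFGS-B", use_grad=True, progressbar=False)

    # Laplace (quadratic) approximation around the MAP, returned as ArviZ InferenceData
    idata = pmx.fit_laplace(optimize_method="L-BFGS-B", chains=2, draws=500)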
pymc_extras/inference/find_map.py
@@ -1,9 +1,9 @@
 import logging

 from collections.abc import Callable
+from importlib.util import find_spec
 from typing import Literal, cast, get_args

-import jax
 import numpy as np
 import pymc as pm
 import pytensor
@@ -30,13 +30,29 @@ VALID_BACKENDS = get_args(GradientBackend)
 def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()

-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
-
     if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
         use_hess = False

+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
     return use_grad, use_hess, use_hessp
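Under the reworked defaulting logic above, giving only one of use_hess / use_hessp now fills the other in as its complement, and a method that can use either defaults to the cheaper Hessian-vector product. A short sketch of the resulting behaviour (assuming the helper is importable as shown):

from pymc_extras.inference.find_map import set_optimizer_function_defaults

# Only use_hessp is given: use_hess is inferred as its complement (False),
# and use_grad falls back to the method's own default.
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
    "Newton-CG", use_grad=None, use_hess=None, use_hessp=True
)

# Nothing is given: the method's capabilities decide, and when a method could
# use either hess or hessp, hessp wins and use_hess is forced to False.
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
    "trust-ncg", use_grad=None, use_hess=None, use_hessp=None
)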
@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
         The nearest positive semi-definite matrix to the input matrix.
     """
     C = (A + A.T) / 2
-    eigval, eigvec = np.linalg.eig(C)
+    eigval, eigvec = np.linalg.eigh(C)
     eigval[eigval < 0] = 0

     return eigvec @ np.diag(eigval) @ eigvec.T
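The switch from np.linalg.eig to np.linalg.eigh exploits the fact that C = (A + A.T) / 2 is symmetric, so the eigendecomposition is real-valued and numerically stable. A small standalone illustration of the same projection (not the library function itself):

import numpy as np

def nearest_psd(A: np.ndarray) -> np.ndarray:
    # Symmetrize, clip negative eigenvalues to zero, and reassemble.
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(C)  # real output for symmetric C
    eigval[eigval < 0] = 0
    return eigvec @ np.diag(eigval) @ eigvec.T

A = np.array([[1.0, 2.0], [2.0, 1.0]])     # eigenvalues 3 and -1
print(np.linalg.eigvalsh(nearest_psd(A)))  # both eigenvalues are now >= 0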
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)


-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
     f_loss: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
@@ -122,6 +138,8 @@ def _compile_jax_gradients(
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
+    import jax
+
     f_hess = None
     f_hessp = None

@@ -152,7 +170,7 @@ def _compile_jax_gradients(
     return f_loss_and_grad, f_hess, f_hessp


-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
     loss: TensorVariable,
     inputs: list[TensorVariable],
     compute_grad: bool,
@@ -177,7 +195,7 @@ def _compile_functions(
     compute_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
     compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -193,19 +211,19 @@ def _compile_functions(
     if compute_grad:
         grads = pytensor.gradient.grad(loss, inputs)
         grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
     else:
-        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]

     if compute_hess:
         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)

     if compute_hessp:
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)

     return [f_loss_and_grad, f_hess, f_hessp]

@@ -240,7 +258,7 @@ def scipy_optimize_funcs_from_loss(
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
     compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -265,6 +283,8 @@ def scipy_optimize_funcs_from_loss(
     )

     use_jax_gradients = (gradient_backend == "jax") and use_grad
+    if use_jax_gradients and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")

     mode = compile_kwargs.get("mode", None)
     if mode is None and use_jax_gradients:
@@ -285,7 +305,7 @@ def scipy_optimize_funcs_from_loss(
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients

-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
         loss=loss,
         inputs=[flat_input],
         compute_grad=compute_grad,
@@ -301,7 +321,7 @@ def scipy_optimize_funcs_from_loss(

     if use_jax_gradients:
         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)

     return f_loss, f_hess, f_hessp
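Taken together, these changes make JAX an optional dependency of find_map.py: the module-level import jax is gone, the import is deferred into _compile_grad_and_hess_to_jax, and the gradient_backend="jax" path is guarded with importlib.util.find_spec. The same pattern in isolation (a sketch, not the library code):

from importlib.util import find_spec

def compile_loss(gradient_backend: str = "pytensor"):
    # Fail fast with a clear error when the optional backend is missing.
    if gradient_backend == "jax" and not find_spec("jax"):
        raise ImportError("JAX must be installed to use JAX gradients")

    if gradient_backend == "jax":
        import jax  # deferred import: only paid for when actually requested

        return jax.jit(lambda x: (x**2).sum())

    return lambda x: (x**2).sum()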
307
327
 
@@ -16,6 +16,7 @@
16
16
  import logging
17
17
 
18
18
  from functools import reduce
19
+ from importlib.util import find_spec
19
20
  from itertools import product
20
21
  from typing import Literal
21
22
 
@@ -231,7 +232,7 @@ def add_data_to_inferencedata(
231
232
  return idata
232
233
 
233
234
 
234
- def fit_mvn_to_MAP(
235
+ def fit_mvn_at_MAP(
235
236
  optimized_point: dict[str, np.ndarray],
236
237
  model: pm.Model | None = None,
237
238
  on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -276,6 +277,9 @@ def fit_mvn_to_MAP(
276
277
  inverse_hessian: np.ndarray
277
278
  The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
278
279
  """
280
+ if gradient_backend == "jax" and not find_spec("jax"):
281
+ raise ImportError("JAX must be installed to use JAX gradients")
282
+
279
283
  model = pm.modelcontext(model)
280
284
  compile_kwargs = {} if compile_kwargs is None else compile_kwargs
281
285
  frozen_model = freeze_dims_and_data(model)
@@ -344,8 +348,10 @@ def sample_laplace_posterior(
344
348
 
345
349
  Parameters
346
350
  ----------
347
- mu
348
- H_inv
351
+ mu: RaveledVars
352
+ The MAP estimate of the model parameters.
353
+ H_inv: np.ndarray
354
+ The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
349
355
  model : Model
350
356
  A PyMC model
351
357
  chains : int
@@ -384,9 +390,7 @@ def sample_laplace_posterior(
384
390
  constrained_rvs, replace={unconstrained_vector: batched_values}
385
391
  )
386
392
 
387
- f_constrain = pm.compile_pymc(
388
- inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
389
- )
393
+ f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
390
394
  posterior_draws = f_constrain(posterior_draws)
391
395
 
392
396
  else:
@@ -472,15 +476,17 @@ def fit_laplace(
472
476
  and 1).
473
477
 
474
478
  .. warning::
475
- This argumnet should be considered highly experimental. It has not been verified if this method produces
479
+ This argument should be considered highly experimental. It has not been verified if this method produces
476
480
  valid draws from the posterior. **Use at your own risk**.
477
481
 
478
482
  gradient_backend: str, default "pytensor"
479
483
  The backend to use for gradient computations. Must be one of "pytensor" or "jax".
480
484
  chains: int, default: 2
481
- The number of sampling chains running in parallel.
485
+ The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
486
+ because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
487
+ compatible with the ArviZ library.
482
488
  draws: int, default: 500
483
- The number of samples to draw from the approximated posterior.
489
+ The number of samples to draw from the approximated posterior. Totals samples will be chains * draws.
484
490
  on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
485
491
  What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
486
492
  If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
@@ -547,11 +553,12 @@ def fit_laplace(
547
553
  **optimizer_kwargs,
548
554
  )
549
555
 
550
- mu, H_inv = fit_mvn_to_MAP(
556
+ mu, H_inv = fit_mvn_at_MAP(
551
557
  optimized_point=optimized_point,
552
558
  model=model,
553
559
  on_bad_cov=on_bad_cov,
554
560
  transform_samples=fit_in_unconstrained_space,
561
+ gradient_backend=gradient_backend,
555
562
  zero_tol=zero_tol,
556
563
  diag_jitter=diag_jitter,
557
564
  compile_kwargs=compile_kwargs,
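The clarified fit_laplace docstring spells out that chains only shapes the output for ArviZ compatibility: the Laplace approximation draws chains * draws independent samples from a single multivariate normal rather than running parallel chains. A hypothetical call showing the resulting shapes (model and values are illustrative):

import numpy as np
import pymc as pm
from pymc_extras.inference.laplace import fit_laplace

with pm.Model():
    mu = pm.Normal("mu", mu=3, sigma=0.5)
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=np.random.normal(3.0, 1.5, size=100))

    idata = fit_laplace(chains=2, draws=500)  # 2 * 500 = 1000 total draws

print(idata.posterior["mu"].shape)  # (2, 500), i.e. (chains, draws)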
pymc_extras/model/marginal/marginal_model.py
@@ -19,7 +19,8 @@ from pymc.model.fgraph import (
     model_free_rv,
     model_from_fgraph,
 )
-from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold, toposort_replace
+from pymc.pytensorf import collect_default_updates, constant_fold, toposort_replace
+from pymc.pytensorf import compile as compile_pymc
 from pymc.util import RandomState, _get_seeds_per_chain
 from pytensor import In, Out
 from pytensor.compile import SharedVariable
pymc_extras/statespace/core/compile.py
@@ -30,7 +30,7 @@ def compile_statespace(

     inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))

-    _f = pm.compile_pymc(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
+    _f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)

     def f(*, draws=1, **params):
         if isinstance(steps, pt.Variable):
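All of these call-site updates track the upstream rename of pm.compile_pymc to pm.compile (hence the new pymc>=5.20 pin). A minimal sketch of the renamed entry point, which accepts the same arguments:

import pymc as pm
import pytensor.tensor as pt

x = pt.dscalar("x")
f = pm.compile([x], x**2)  # previously pm.compile_pymc([x], x**2)
print(f(3.0))  # 9.0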
pymc_extras-0.2.3/pymc_extras/version.txt (new file)
@@ -0,0 +1 @@
+0.2.3
pymc_extras.egg-info/PKG-INFO
@@ -1,11 +1,11 @@
 Metadata-Version: 2.2
 Name: pymc-extras
-Version: 0.2.2
+Version: 0.2.3
 Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
 Home-page: http://github.com/pymc-devs/pymc-extras
 Maintainer: PyMC Developers
 Maintainer-email: pymc.devs@gmail.com
-License: Apache License, Version 2.0
+License: Apache-2.0
 Classifier: Development Status :: 5 - Production/Stable
 Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3
@@ -20,8 +20,9 @@ Classifier: Operating System :: OS Independent
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: pymc>=5.19.1
+Requires-Dist: pymc>=5.20
 Requires-Dist: scikit-learn
+Requires-Dist: better-optimize
 Provides-Extra: dask-histogram
 Requires-Dist: dask[complete]; extra == "dask-histogram"
 Requires-Dist: xhistogram; extra == "dask-histogram"
pymc_extras.egg-info/requires.txt
@@ -1,5 +1,6 @@
-pymc>=5.19.1
+pymc>=5.20
 scikit-learn
+better-optimize

 [complete]
 dask[complete]
requirements.txt (new in 0.2.3)
@@ -0,0 +1,3 @@
+pymc>=5.20
+scikit-learn
+better-optimize
setup.py
@@ -25,7 +25,7 @@ DESCRIPTION = "A home for new additions to PyMC, which may include unusual proba
 AUTHOR = "PyMC Developers"
 AUTHOR_EMAIL = "pymc.devs@gmail.com"
 URL = "http://github.com/pymc-devs/pymc-extras"
-LICENSE = "Apache License, Version 2.0"
+LICENSE = "Apache-2.0"

 classifiers = [
     "Development Status :: 5 - Production/Stable",
tests/test_find_map.py
@@ -54,24 +54,28 @@ def test_jax_functions_from_graph(gradient_backend: GradientBackend):


 @pytest.mark.parametrize(
-    "method, use_grad, use_hess",
+    "method, use_grad, use_hess, use_hessp",
     [
-        ("nelder-mead", False, False),
-        ("powell", False, False),
-        ("CG", True, False),
-        ("BFGS", True, False),
-        ("L-BFGS-B", True, False),
-        ("TNC", True, False),
-        ("SLSQP", True, False),
-        ("dogleg", True, True),
-        ("trust-ncg", True, True),
-        ("trust-exact", True, True),
-        ("trust-krylov", True, True),
-        ("trust-constr", True, True),
+        ("nelder-mead", False, False, False),
+        ("powell", False, False, False),
+        ("CG", True, False, False),
+        ("BFGS", True, False, False),
+        ("L-BFGS-B", True, False, False),
+        ("TNC", True, False, False),
+        ("SLSQP", True, False, False),
+        ("dogleg", True, True, False),
+        ("Newton-CG", True, True, False),
+        ("Newton-CG", True, False, True),
+        ("trust-ncg", True, True, False),
+        ("trust-ncg", True, False, True),
+        ("trust-exact", True, True, False),
+        ("trust-krylov", True, True, False),
+        ("trust-krylov", True, False, True),
+        ("trust-constr", True, True, False),
     ],
 )
 @pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
-def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng):
+def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng):
     extra_kwargs = {}
     if method == "dogleg":
         # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
@@ -88,6 +92,7 @@ def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend,
         **extra_kwargs,
         use_grad=use_grad,
         use_hess=use_hess,
+        use_hessp=use_hessp,
         progressbar=False,
         gradient_backend=gradient_backend,
         compile_kwargs={"mode": "JAX"},
tests/test_laplace.py
@@ -19,10 +19,10 @@ import pytest

 import pymc_extras as pmx

-from pymc_extras.inference.find_map import find_MAP
+from pymc_extras.inference.find_map import GradientBackend, find_MAP
 from pymc_extras.inference.laplace import (
     fit_laplace,
-    fit_mvn_to_MAP,
+    fit_mvn_at_MAP,
     sample_laplace_posterior,
 )

@@ -37,7 +37,11 @@ def rng():
     "ignore:hessian will stop negating the output in a future version of PyMC.\n"
     + "To suppress this warning set `negate_output=False`:FutureWarning",
 )
-def test_laplace():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analyses, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
@@ -55,7 +59,13 @@ def test_laplace():
     vars = [mu, logsigma]

     idata = pmx.fit(
-        method="laplace", optimize_method="trust-ncg", draws=draws, random_seed=173300, chains=1
+        method="laplace",
+        optimize_method="trust-ncg",
+        draws=draws,
+        random_seed=173300,
+        chains=1,
+        compile_kwargs={"mode": mode},
+        gradient_backend=gradient_backend,
     )

     assert idata.posterior["mu"].shape == (1, draws)
@@ -71,7 +81,11 @@ def test_laplace():
     np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)


-def test_laplace_only_fit():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace_only_fit(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analyses, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
@@ -90,8 +104,8 @@ def test_laplace_only_fit():
         method="laplace",
         optimize_method="BFGS",
         progressbar=True,
-        gradient_backend="jax",
-        compile_kwargs={"mode": "JAX"},
+        gradient_backend=gradient_backend,
+        compile_kwargs={"mode": mode},
         optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100),
         random_seed=173300,
     )
@@ -111,8 +125,11 @@ def test_laplace_only_fit():
     [True, False],
     ids=["transformed", "untransformed"],
 )
-@pytest.mark.parametrize("mode", ["JAX", None], ids=["jax", "pytensor"])
-def test_fit_laplace_coords(rng, transform_samples, mode):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: GradientBackend):
     coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as model:
         mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"])
@@ -131,13 +148,13 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
             use_hessp=True,
             progressbar=False,
             compile_kwargs=dict(mode=mode),
-            gradient_backend="jax" if mode == "JAX" else "pytensor",
+            gradient_backend=gradient_backend,
         )

     for value in optimized_point.values():
         assert value.shape == (3,)

-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         transform_samples=transform_samples,
@@ -163,7 +180,11 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
     ]


-def test_fit_laplace_ragged_coords(rng):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng):
     coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as ragged_dim_model:
         X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])
@@ -188,8 +209,8 @@ def test_fit_laplace_ragged_coords(rng):
         progressbar=False,
         use_grad=True,
         use_hessp=True,
-        gradient_backend="jax",
-        compile_kwargs={"mode": "JAX"},
+        gradient_backend=gradient_backend,
+        compile_kwargs={"mode": mode},
     )

     assert idata["posterior"].beta.shape[-2:] == (3, 2)
@@ -206,7 +227,11 @@ def test_fit_laplace_ragged_coords(rng):
     [True, False],
     ids=["transformed", "untransformed"],
 )
-def test_fit_laplace(fit_in_unconstrained_space):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: GradientBackend):
     with pm.Model() as simp_model:
         mu = pm.Normal("mu", mu=3, sigma=0.5)
         sigma = pm.Exponential("sigma", 1)
@@ -223,6 +248,8 @@ def test_fit_laplace(fit_in_unconstrained_space):
         use_hessp=True,
         fit_in_unconstrained_space=fit_in_unconstrained_space,
         optimizer_kwargs=dict(maxiter=100_000, tol=1e-100),
+        compile_kwargs={"mode": mode},
+        gradient_backend=gradient_backend,
     )

     np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1)
pymc_extras-0.2.2/pymc_extras/version.txt (removed)
@@ -1 +0,0 @@
-0.2.2
pymc_extras-0.2.2/requirements.txt (removed)
@@ -1,2 +0,0 @@
-pymc>=5.19.1
-scikit-learn