pymc-extras 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
Files changed (65)
  1. pymc_extras/__init__.py +5 -1
  2. pymc_extras/deserialize.py +224 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/inference/find_map.py +62 -17
  6. pymc_extras/inference/laplace.py +10 -7
  7. pymc_extras/prior.py +1356 -0
  8. pymc_extras/statespace/core/statespace.py +191 -52
  9. pymc_extras/statespace/filters/distributions.py +15 -16
  10. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  11. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  12. pymc_extras/statespace/models/ETS.py +10 -0
  13. pymc_extras/statespace/models/SARIMAX.py +26 -5
  14. pymc_extras/statespace/models/VARMAX.py +12 -2
  15. pymc_extras/statespace/models/structural.py +18 -5
  16. pymc_extras-0.2.7.dist-info/METADATA +321 -0
  17. pymc_extras-0.2.7.dist-info/RECORD +66 -0
  18. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/WHEEL +1 -2
  19. pymc_extras/utils/pivoted_cholesky.py +0 -69
  20. pymc_extras/version.py +0 -11
  21. pymc_extras/version.txt +0 -1
  22. pymc_extras-0.2.5.dist-info/METADATA +0 -112
  23. pymc_extras-0.2.5.dist-info/RECORD +0 -108
  24. pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
  25. tests/__init__.py +0 -13
  26. tests/distributions/__init__.py +0 -19
  27. tests/distributions/test_continuous.py +0 -185
  28. tests/distributions/test_discrete.py +0 -210
  29. tests/distributions/test_discrete_markov_chain.py +0 -258
  30. tests/distributions/test_multivariate.py +0 -304
  31. tests/distributions/test_transform.py +0 -77
  32. tests/model/__init__.py +0 -0
  33. tests/model/marginal/__init__.py +0 -0
  34. tests/model/marginal/test_distributions.py +0 -132
  35. tests/model/marginal/test_graph_analysis.py +0 -182
  36. tests/model/marginal/test_marginal_model.py +0 -967
  37. tests/model/test_model_api.py +0 -38
  38. tests/statespace/__init__.py +0 -0
  39. tests/statespace/test_ETS.py +0 -411
  40. tests/statespace/test_SARIMAX.py +0 -405
  41. tests/statespace/test_VARMAX.py +0 -184
  42. tests/statespace/test_coord_assignment.py +0 -181
  43. tests/statespace/test_distributions.py +0 -270
  44. tests/statespace/test_kalman_filter.py +0 -326
  45. tests/statespace/test_representation.py +0 -175
  46. tests/statespace/test_statespace.py +0 -872
  47. tests/statespace/test_statespace_JAX.py +0 -156
  48. tests/statespace/test_structural.py +0 -836
  49. tests/statespace/utilities/__init__.py +0 -0
  50. tests/statespace/utilities/shared_fixtures.py +0 -9
  51. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  52. tests/statespace/utilities/test_helpers.py +0 -310
  53. tests/test_blackjax_smc.py +0 -222
  54. tests/test_find_map.py +0 -103
  55. tests/test_histogram_approximation.py +0 -109
  56. tests/test_laplace.py +0 -281
  57. tests/test_linearmodel.py +0 -208
  58. tests/test_model_builder.py +0 -306
  59. tests/test_pathfinder.py +0 -297
  60. tests/test_pivoted_cholesky.py +0 -24
  61. tests/test_printing.py +0 -98
  62. tests/test_prior_from_trace.py +0 -172
  63. tests/test_splines.py +0 -77
  64. tests/utils.py +0 -0
  65. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/licenses/LICENSE +0 -0
pymc_extras/__init__.py CHANGED
@@ -13,6 +13,8 @@
  # limitations under the License.
  import logging
 
+ from importlib.metadata import version
+
  from pymc_extras import gp, statespace, utils
  from pymc_extras.distributions import *
  from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
@@ -22,7 +24,6 @@ from pymc_extras.model.marginal.marginal_model import (
      recover_marginals,
  )
  from pymc_extras.model.model_api import as_model
- from pymc_extras.version import __version__
 
  _log = logging.getLogger("pmx")
 
@@ -31,3 +32,6 @@ if not logging.root.handlers:
      if len(_log.handlers) == 0:
          handler = logging.StreamHandler()
          _log.addHandler(handler)
+
+
+ __version__ = version("pymc-extras")
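
With this release the version string is resolved at import time from the installed distribution's metadata rather than from the removed `pymc_extras/version.py` (files 20 and 21 above). A quick sanity check of the new behavior, as a minimal sketch assuming pymc-extras is installed:

import pymc_extras
from importlib.metadata import version

# __version__ is now read from package metadata, so it always matches
# what the installer reports for the pymc-extras distribution.
assert pymc_extras.__version__ == version("pymc-extras")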
pymc_extras/deserialize.py ADDED
@@ -0,0 +1,224 @@
+ """Deserialize dictionaries into Python objects.
+
+ This is a two step process:
+
+ 1. Determine if the data is of the correct type.
+ 2. Deserialize the data into a python object.
+
+ Examples
+ --------
+ Make use of the already registered deserializers:
+
+ .. code-block:: python
+
+     from pymc_extras.deserialize import deserialize
+
+     prior_class_data = {
+         "dist": "Normal",
+         "kwargs": {"mu": 0, "sigma": 1}
+     }
+     prior = deserialize(prior_class_data)
+     # Prior("Normal", mu=0, sigma=1)
+
+ Register custom class deserialization:
+
+ .. code-block:: python
+
+     from pymc_extras.deserialize import register_deserialization
+
+     class MyClass:
+         def __init__(self, value: int):
+             self.value = value
+
+         def to_dict(self) -> dict:
+             # Example of what the to_dict method might look like.
+             return {"value": self.value}
+
+     register_deserialization(
+         is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
+         deserialize=lambda data: MyClass(value=data["value"]),
+     )
+
+ Deserialize data into that custom class:
+
+ .. code-block:: python
+
+     from pymc_extras.deserialize import deserialize
+
+     data = {"value": 42}
+     obj = deserialize(data)
+     assert isinstance(obj, MyClass)
+
+
+ """
+
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from typing import Any
+
+ IsType = Callable[[Any], bool]
+ Deserialize = Callable[[Any], Any]
+
+
+ @dataclass
+ class Deserializer:
+     """Object to store information required for deserialization.
+
+     All deserializers should be stored via the :func:`register_deserialization` function
+     instead of creating this object directly.
+
+     Attributes
+     ----------
+     is_type : IsType
+         Function to determine if the data is of the correct type.
+     deserialize : Deserialize
+         Function to deserialize the data.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         from typing import Any
+
+         class MyClass:
+             def __init__(self, value: int):
+                 self.value = value
+
+         from pymc_extras.deserialize import Deserializer
+
+         def is_type(data: Any) -> bool:
+             return data.keys() == {"value"} and isinstance(data["value"], int)
+
+         def deserialize(data: dict) -> MyClass:
+             return MyClass(value=data["value"])
+
+         deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)
+
+     """
+
+     is_type: IsType
+     deserialize: Deserialize
+
+
+ DESERIALIZERS: list[Deserializer] = []
+
+
+ class DeserializableError(Exception):
+     """Error raised when data cannot be deserialized."""
+
+     def __init__(self, data: Any):
+         self.data = data
+         super().__init__(
+             f"Couldn't deserialize {data}. Use register_deserialization to add a deserialization mapping."
+         )
+
+
+ def deserialize(data: Any) -> Any:
+     """Deserialize a dictionary into a Python object.
+
+     Use the :func:`register_deserialization` function to add custom deserializations.
+
+     Deserialization is a two step process due to the dynamic nature of the data:
+
+     1. Determine if the data is of the correct type.
+     2. Deserialize the data into a Python object.
+
+     Each registered deserialization is checked in order until one is found that can
+     deserialize the data. If no deserialization is found, a :class:`DeserializableError` is raised.
+
+     A :class:`DeserializableError` is raised when the data fails to be deserialized
+     by any of the registered deserializers.
+
+     Parameters
+     ----------
+     data : Any
+         The data to deserialize.
+
+     Returns
+     -------
+     Any
+         The deserialized object.
+
+     Raises
+     ------
+     DeserializableError
+         Raised when the data doesn't match any registered deserializations
+         or fails to be deserialized.
+
+     Examples
+     --------
+     Deserialize a :class:`pymc_extras.prior.Prior` object:
+
+     .. code-block:: python
+
+         from pymc_extras.deserialize import deserialize
+
+         data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
+         prior = deserialize(data)
+         # Prior("Normal", mu=0, sigma=1)
+
+     """
+     for mapping in DESERIALIZERS:
+         try:
+             is_type = mapping.is_type(data)
+         except Exception:
+             is_type = False
+
+         if not is_type:
+             continue
+
+         try:
+             return mapping.deserialize(data)
+         except Exception as e:
+             raise DeserializableError(data) from e
+     else:
+         raise DeserializableError(data)
+
+
+ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
+     """Register an arbitrary deserialization.
+
+     Use the :func:`deserialize` function to then deserialize data using all registered
+     deserialize functions.
+
+     Parameters
+     ----------
+     is_type : Callable[[Any], bool]
+         Function to determine if the data is of the correct type.
+     deserialize : Callable[[dict], Any]
+         Function to deserialize the data of that type.
+
+     Examples
+     --------
+     Register a custom class deserialization:
+
+     .. code-block:: python
+
+         from pymc_extras.deserialize import register_deserialization
+
+         class MyClass:
+             def __init__(self, value: int):
+                 self.value = value
+
+             def to_dict(self) -> dict:
+                 # Example of what the to_dict method might look like.
+                 return {"value": self.value}
+
+         register_deserialization(
+             is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
+             deserialize=lambda data: MyClass(value=data["value"]),
+         )
+
+     Use that custom class deserialization:
+
+     .. code-block:: python
+
+         from pymc_extras.deserialize import deserialize
+
+         data = {"value": 42}
+         obj = deserialize(data)
+         assert isinstance(obj, MyClass)
+
+     """
+     mapping = Deserializer(is_type=is_type, deserialize=deserialize)
+     DESERIALIZERS.append(mapping)
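
The new module keeps a module-level registry (`DESERIALIZERS`) that `deserialize` scans in order; data matching no registered `is_type` predicate raises `DeserializableError`. A minimal sketch of that failure path (the payload here is hypothetical):

from pymc_extras.deserialize import DeserializableError, deserialize

try:
    deserialize({"unregistered_key": 42})  # no is_type predicate matches this dict
except DeserializableError as err:
    print(err)
    # Couldn't deserialize {'unregistered_key': 42}. Use register_deserialization
    # to add a deserialization mapping.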
pymc_extras/distributions/continuous.py CHANGED
@@ -81,7 +81,7 @@ class GenExtreme(Continuous):
 
          \left\{x: 1 + \xi\left(\frac{x-\mu}{\sigma}\right) > 0 \right\}.
 
-     Note that this parametrization is per Coles (2001), and differs from that of
+     Note that this parametrization is per Coles (2001) [1]_, and differs from that of
      Scipy in the sign of the shape parameter, :math:`\xi`.
 
      .. plot::
@@ -132,7 +132,7 @@ class GenExtreme(Continuous):
 
      References
      ----------
-     .. [Coles2001] Coles, S.G. (2001).
+     .. [1] Coles, S.G. (2001).
          An Introduction to the Statistical Modeling of Extreme Values
          Springer-Verlag, London
 
@@ -260,6 +260,7 @@ class Chi:
      Examples
      --------
      .. code-block:: python
+
          import pymc as pm
          from pymc_extras.distributions import Chi
 
pymc_extras/distributions/discrete.py CHANGED
@@ -116,6 +116,7 @@ class GeneralizedPoisson(pm.distributions.Discrete):
 
      .. math:: f(x \mid \mu, \lambda) =
          \frac{\mu (\mu + \lambda x)^{x-1} e^{-\mu - \lambda x}}{x!}
+
      ======== ======================================
      Support  :math:`x \in \mathbb{N}_0`
      Mean     :math:`\frac{\mu}{1 - \lambda}`
@@ -135,9 +136,10 @@ class GeneralizedPoisson(pm.distributions.Discrete):
      When lam < 0, the mean is greater than the variance (underdispersion).
      When lam > 0, the mean is less than the variance (overdispersion).
 
+     The PMF is taken from [1]_ and the random generator function is adapted from [2]_.
+
      References
      ----------
-     The PMF is taken from [1] and the random generator function is adapted from [2].
      .. [1] Consul, PoC, and Felix Famoye. "Generalized Poisson regression model."
          Communications in Statistics-Theory and Methods 21.1 (1992): 89-109.
      .. [2] Famoye, Felix. "Generalized Poisson random variate generation." American
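
The dispersion note in the docstring above is the distribution's key feature: the mean is mu / (1 - lam), and positive lam inflates the variance relative to the mean. A quick empirical check, as a sketch assuming `GeneralizedPoisson` is exported from `pymc_extras.distributions`:

import pymc as pm
from pymc_extras.distributions import GeneralizedPoisson

# Draw from the distribution directly and compare sample moments.
x = GeneralizedPoisson.dist(mu=5.0, lam=0.3)
draws = pm.draw(x, draws=50_000, random_seed=1)

print(draws.mean())  # approx. mu / (1 - lam) = 7.14
print(draws.var())   # larger than the mean when lam > 0 (overdispersion)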
pymc_extras/inference/find_map.py CHANGED
@@ -9,7 +9,7 @@ import pymc as pm
  import pytensor
  import pytensor.tensor as pt
 
- from better_optimize import minimize
+ from better_optimize import basinhopping, minimize
  from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
  from pymc.blocking import DictToArrayBijection, RaveledVars
  from pymc.initial_point import make_initial_point_fn
@@ -146,7 +146,7 @@ def _compile_grad_and_hess_to_jax(
      orig_loss_fn = f_loss.vm.jit_fn
 
      @jax.jit
-     def loss_fn_jax_grad(x, *shared):
+     def loss_fn_jax_grad(x):
          return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
 
      f_loss_and_grad = loss_fn_jax_grad
@@ -301,6 +301,14 @@ def scipy_optimize_funcs_from_loss(
          point=initial_point_dict, outputs=[loss], inputs=inputs
      )
 
+     # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
+     # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
+     # away.
+     if use_jax_gradients:
+         from pymc.sampling.jax import _replace_shared_variables
+
+         [loss] = _replace_shared_variables([loss])
+
      compute_grad = use_grad and not use_jax_gradients
      compute_hess = use_hess and not use_jax_gradients
      compute_hessp = use_hessp and not use_jax_gradients
@@ -327,7 +335,7 @@ def scipy_optimize_funcs_from_loss(
 
 
  def find_MAP(
-     method: minimize_method,
+     method: minimize_method | Literal["basinhopping"],
      *,
      model: pm.Model | None = None,
      use_grad: bool | None = None,
@@ -344,14 +352,17 @@ def find_MAP(
      **optimizer_kwargs,
  ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
      """
-     Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.minimize.
+     Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize.
 
      Parameters
      ----------
      model : pm.Model
          The PyMC model to be fit. If None, the current model context is used.
      method : str
-         The optimization method to use. See scipy.optimize.minimize documentation for details.
+         The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+         trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+         See scipy.optimize.minimize documentation for details.
      use_grad : bool | None, optional
          Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
          the ``method``.
@@ -379,7 +390,9 @@ def find_MAP(
      compile_kwargs: dict, optional
          Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
      **optimizer_kwargs
-         Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.
+         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+         ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
 
      Returns
      -------
@@ -405,6 +418,18 @@ def find_MAP(
      initial_params = DictToArrayBijection.map(
          {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
      )
+
+     do_basinhopping = method == "basinhopping"
+     minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+     if do_basinhopping:
+         # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+         # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+         # if one isn't provided.
+
+         method = minimizer_kwargs.pop("method", "L-BFGS-B")
+         minimizer_kwargs["method"] = method
+
      use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
          method, use_grad, use_hess, use_hessp
      )
@@ -423,17 +448,37 @@ def find_MAP(
      args = optimizer_kwargs.pop("args", None)
 
      # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
-     # if so. That is why it is not set here, regardless of user settings.
-     optimizer_result = minimize(
-         f=f_logp,
-         x0=cast(np.ndarray[float], initial_params.data),
-         args=args,
-         hess=f_hess,
-         hessp=f_hessp,
-         progressbar=progressbar,
-         method=method,
-         **optimizer_kwargs,
-     )
+     # if so. That is why the jac argument is not passed here in either branch.
+
+     if do_basinhopping:
+         if "args" not in minimizer_kwargs:
+             minimizer_kwargs["args"] = args
+         if "hess" not in minimizer_kwargs:
+             minimizer_kwargs["hess"] = f_hess
+         if "hessp" not in minimizer_kwargs:
+             minimizer_kwargs["hessp"] = f_hessp
+         if "method" not in minimizer_kwargs:
+             minimizer_kwargs["method"] = method
+
+         optimizer_result = basinhopping(
+             func=f_logp,
+             x0=cast(np.ndarray[float], initial_params.data),
+             progressbar=progressbar,
+             minimizer_kwargs=minimizer_kwargs,
+             **optimizer_kwargs,
+         )
+
+     else:
+         optimizer_result = minimize(
+             f=f_logp,
+             x0=cast(np.ndarray[float], initial_params.data),
+             args=args,
+             hess=f_hess,
+             hessp=f_hessp,
+             progressbar=progressbar,
+             method=method,
+             **optimizer_kwargs,
+         )
 
      raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
      unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
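
Taken together, these changes let `find_MAP` dispatch to `scipy.optimize.basinhopping` for multi-modal problems, with the inner local optimizer controlled through `minimizer_kwargs` (defaulting to L-BFGS-B). A minimal sketch on a toy model; `niter` is a standard `basinhopping` argument forwarded via `**optimizer_kwargs`:

import pymc as pm
from pymc_extras.inference import find_MAP

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0)

# method="basinhopping" routes to scipy.optimize.basinhopping; the inner
# optimizer is read from minimizer_kwargs, falling back to L-BFGS-B.
map_point = find_MAP(
    method="basinhopping",
    model=model,
    minimizer_kwargs={"method": "L-BFGS-B"},
    niter=10,
)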
pymc_extras/inference/laplace.py CHANGED
@@ -416,7 +416,7 @@ def sample_laplace_posterior(
 
 
  def fit_laplace(
-     optimize_method: minimize_method = "BFGS",
+     optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
      *,
      model: pm.Model | None = None,
      use_grad: bool | None = None,
@@ -449,8 +449,11 @@
      ----------
      model : pm.Model
          The PyMC model to be fit. If None, the current model context is used.
-     optimize_method : str
-         The optimization method to use. See scipy.optimize.minimize documentation for details.
+     method : str
+         The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+         trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+         See scipy.optimize.minimize documentation for details.
      use_grad : bool | None, optional
          Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
          the ``method``.
@@ -500,10 +503,10 @@ def fit_laplace(
      diag_jitter: float | None
          A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
          If None, no jitter is added. Default is 1e-8.
-     optimizer_kwargs: dict, optional
-         Additional keyword arguments to pass to scipy.minimize. See the documentation for scipy.optimize.minimize for
-         details. Arguments that are typically passed via ``options`` will be automatically extracted without the need
-         to use a nested dictionary.
+     optimizer_kwargs
+         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+         ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
      compile_kwargs: dict, optional
          Additional keyword arguments to pass to pytensor.function.
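
As with `find_MAP`, the widened `optimize_method` type means the Laplace fit can drive its MAP step with `scipy.optimize.basinhopping`. A sketch under the same assumptions (toy data; extra keyword arguments forwarded as the docstring above describes):

import pymc as pm
from pymc_extras import fit_laplace

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=[0.1, -0.4, 0.7])

    # The Laplace approximation is centered on the MAP point found by
    # basinhopping; the inner optimizer again defaults to L-BFGS-B.
    idata = fit_laplace(optimize_method="basinhopping")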