pymc-extras 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymc_extras/__init__.py CHANGED
@@ -16,7 +16,11 @@ import logging
  from pymc_extras import gp, statespace, utils
  from pymc_extras.distributions import *
  from pymc_extras.inference.fit import fit
- from pymc_extras.model.marginal.marginal_model import MarginalModel, marginalize
+ from pymc_extras.model.marginal.marginal_model import (
+     MarginalModel,
+     marginalize,
+     recover_marginals,
+ )
  from pymc_extras.model.model_api import as_model
  from pymc_extras.version import __version__
 
@@ -214,8 +214,8 @@ class DiscreteMarkovChain(Distribution):
  discrete_mc_op = DiscreteMarkovChainRV(
      inputs=[P_, steps_, init_dist_, state_rng],
      outputs=[state_next_rng, discrete_mc_],
-     ndim_supp=1,
      n_lags=n_lags,
+     extended_signature="(p,p),(),(p),[rng]->[rng],(t)",
  )
 
  discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng)
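Note: the hunk above replaces the ndim_supp=1 argument with a gufunc-style extended_signature string, "(p,p),(),(p),[rng]->[rng],(t)", spelling out the core dimensions of each input (P, steps, init_dist, rng) and output (rng, the length-t chain). User-facing usage of the distribution is unchanged. A minimal sketch, assuming the documented P/init_dist/steps keywords and the top-level re-export (these names are not shown in this diff):

    import numpy as np
    import pymc as pm
    from pymc_extras import DiscreteMarkovChain  # top-level re-export assumed

    with pm.Model():
        # 3-state chain: each row of P is a distribution over the next state
        P = pm.Dirichlet("P", a=np.ones((3, 3)))
        init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=20)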
@@ -11,7 +11,6 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- from importlib.util import find_spec
 
 
  def fit(method, **kwargs):
@@ -31,9 +30,6 @@ def fit(method, **kwargs):
  arviz.InferenceData
  """
  if method == "pathfinder":
-     if find_spec("blackjax") is None:
-         raise RuntimeError("Need BlackJAX to use `pathfinder`")
-
      from pymc_extras.inference.pathfinder import fit_pathfinder
 
      return fit_pathfinder(**kwargs)
@@ -0,0 +1,3 @@
+ from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
+
+ __all__ = ["fit_pathfinder"]
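With the BlackJAX check removed from fit() above, method="pathfinder" now resolves through this new subpackage (presumably its __init__.py) to fit_pathfinder. A minimal sketch of the dispatch, assuming fit_pathfinder picks up the model from the active PyMC model context and accepts a num_draws keyword (neither is shown in this diff):

    import numpy as np
    import pymc as pm
    from pymc_extras.inference.fit import fit

    with pm.Model():
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("obs", mu, 1.0, observed=np.random.normal(size=100))
        # kwargs are forwarded untouched: fit(method="pathfinder", **kwargs) -> fit_pathfinder(**kwargs)
        idata = fit(method="pathfinder", num_draws=500)  # num_draws and model-from-context are assumptions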
@@ -0,0 +1,139 @@
+ import logging
+ import warnings as _warnings
+
+ from dataclasses import dataclass, field
+ from typing import Literal
+
+ import arviz as az
+ import numpy as np
+
+ from numpy.typing import NDArray
+ from scipy.special import logsumexp
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass(frozen=True)
+ class ImportanceSamplingResult:
+     """container for importance sampling results"""
+
+     samples: NDArray
+     pareto_k: float | None = None
+     warnings: list[str] = field(default_factory=list)
+     method: str = "none"
+
+
+ def importance_sampling(
+     samples: NDArray,
+     logP: NDArray,
+     logQ: NDArray,
+     num_draws: int,
+     method: Literal["psis", "psir", "identity", "none"] | None,
+     random_seed: int | None = None,
+ ) -> ImportanceSamplingResult:
+     """Pareto Smoothed Importance Resampling (PSIR)
+     This implements the Pareto Smoothed Importance Resampling (PSIR) method, as described in Algorithm 5 of Zhang et al. (2022). PSIR follows a similar approach to the Algorithm 1 PSIS diagnostic from Yao et al. (2018). However, before computing the importance ratio r_s, logP and logQ are adjusted to account for the number of estimators (or paths). The process involves resampling from the original sample with replacement, with probabilities proportional to the computed importance weights from PSIS.
+
+     Parameters
+     ----------
+     samples : NDArray
+         samples from proposal distribution, shape (L, M, N)
+     logP : NDArray
+         log probability values of target distribution, shape (L, M)
+     logQ : NDArray
+         log probability values of proposal distribution, shape (L, M)
+     num_draws : int
+         number of draws to return where num_draws <= samples.shape[0]
+     method : str, optional
+         importance sampling method to use. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is, of size num_draws_per_path * num_paths.
+     random_seed : int | None
+
+     Returns
+     -------
+     ImportanceSamplingResult
+         importance sampled draws and other info based on the specified method
+
+     Future work!
+     ----------
+     - Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
+     - Implement Algorithm 2 VSBC marginal diagnostics from Yao et al. (2018)
+     - Incorporate these various diagnostics, sampling approaches and weighting functions into VI algorithms.
+
+     References
+     ----------
+     Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple Importance Sampling. Statistical Science, 34(1), 129-155. https://doi.org/10.1214/18-STS668
+
+     Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating Variational Inference. arXiv:1802.02538 [Stat]. http://arxiv.org/abs/1802.02538
+
+     Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
+     """
+
+     warnings = []
+     num_paths, _, N = samples.shape
+
+     if method == "none":
+         warnings.append(
+             "Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
+         )
+         return ImportanceSamplingResult(samples=samples, warnings=warnings)
+     else:
+         samples = samples.reshape(-1, N)
+         logP = logP.ravel()
+         logQ = logQ.ravel()
+
+         # adjust log densities
+         log_I = np.log(num_paths)
+         logP -= log_I
+         logQ -= log_I
+         logiw = logP - logQ
+
+         with _warnings.catch_warnings():
+             _warnings.filterwarnings(
+                 "ignore", category=RuntimeWarning, message="overflow encountered in exp"
+             )
+             if method == "psis":
+                 replace = False
+                 logiw, pareto_k = az.psislw(logiw)
+             elif method == "psir":
+                 replace = True
+                 logiw, pareto_k = az.psislw(logiw)
+             elif method == "identity":
+                 replace = False
+                 pareto_k = None
+             else:
+                 raise ValueError(f"Invalid importance sampling method: {method}")
+
+     # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
+     # Pareto k may not be a good diagnostic for Pathfinder.
+     # TODO: Find replacement diagnostics for Pathfinder.
+
+     p = np.exp(logiw - logsumexp(logiw))
+     rng = np.random.default_rng(random_seed)
+
+     try:
+         resampled = rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)
+         return ImportanceSamplingResult(
+             samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
+         )
+     except ValueError as e1:
+         if "Fewer non-zero entries in p than size" in str(e1):
+             num_nonzero = np.where(np.nonzero(p)[0], 1, 0).sum()
+             warnings.append(
+                 f"Not enough valid samples: {num_nonzero} available out of {num_draws} requested. Switching to psir importance sampling."
+             )
+             try:
+                 resampled = rng.choice(
+                     samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0
+                 )
+                 return ImportanceSamplingResult(
+                     samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
+                 )
+             except ValueError as e2:
+                 logger.error(
+                     "Importance sampling failed even with psir importance sampling. "
+                     "This might indicate invalid probability weights or insufficient valid samples."
+                 )
+                 raise ValueError(
+                     "Importance sampling failed for both with and without replacement"
+                 ) from e2
+         raise
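The file added in the hunk above exposes importance_sampling with the signature shown (samples of shape (L, M, N), logP/logQ of shape (L, M)). A minimal sketch on synthetic arrays; the module path is an assumption, since the diff does not show the new file's location:

    import numpy as np
    from pymc_extras.inference.pathfinder.importance_sampling import importance_sampling  # path assumed

    L, M, N = 4, 250, 3  # paths, draws per path, parameters
    rng = np.random.default_rng(0)
    samples = rng.normal(size=(L, M, N))
    logP = -0.5 * (samples**2).sum(-1)                # stand-in target log densities, shape (L, M)
    logQ = logP + rng.normal(scale=0.1, size=(L, M))  # stand-in proposal log densities, shape (L, M)

    result = importance_sampling(samples, logP, logQ, num_draws=500, method="psis", random_seed=0)
    print(result.samples.shape, result.method, result.pareto_k)

With method="psis" the L * M pooled draws are reweighted via az.psislw and subsampled without replacement (replace=False), so num_draws should not exceed the number of pooled draws.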
@@ -0,0 +1,190 @@
+ import logging
+
+ from collections.abc import Callable
+ from dataclasses import dataclass, field
+ from enum import Enum, auto
+
+ import numpy as np
+
+ from numpy.typing import NDArray
+ from scipy.optimize import minimize
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass(slots=True)
+ class LBFGSHistory:
+     """History of LBFGS iterations."""
+
+     x: NDArray[np.float64]
+     g: NDArray[np.float64]
+     count: int
+
+     def __post_init__(self):
+         self.x = np.ascontiguousarray(self.x, dtype=np.float64)
+         self.g = np.ascontiguousarray(self.g, dtype=np.float64)
+
+
+ @dataclass(slots=True)
+ class LBFGSHistoryManager:
+     """manages and stores the history of lbfgs optimisation iterations.
+
+     Parameters
+     ----------
+     value_grad_fn : Callable
+         function that returns tuple of (value, gradient) given input x
+     x0 : NDArray
+         initial position
+     maxiter : int
+         maximum number of iterations to store
+     """
+
+     value_grad_fn: Callable[[NDArray[np.float64]], tuple[np.float64, NDArray[np.float64]]]
+     x0: NDArray[np.float64]
+     maxiter: int
+     x_history: NDArray[np.float64] = field(init=False)
+     g_history: NDArray[np.float64] = field(init=False)
+     count: int = field(init=False)
+
+     def __post_init__(self) -> None:
+         self.x_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)
+         self.g_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)
+         self.count = 0
+
+         value, grad = self.value_grad_fn(self.x0)
+         if np.all(np.isfinite(grad)) and np.isfinite(value):
+             self.add_entry(self.x0, grad)
+
+     def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None:
+         """adds new position and gradient to history.
+
+         Parameters
+         ----------
+         x : NDArray
+             position vector
+         g : NDArray
+             gradient vector
+         """
+         self.x_history[self.count] = x
+         self.g_history[self.count] = g
+         self.count += 1
+
+     def get_history(self) -> LBFGSHistory:
+         """returns history of optimisation iterations."""
+         return LBFGSHistory(
+             x=self.x_history[: self.count], g=self.g_history[: self.count], count=self.count
+         )
+
+     def __call__(self, x: NDArray[np.float64]) -> None:
+         value, grad = self.value_grad_fn(x)
+         if np.all(np.isfinite(grad)) and np.isfinite(value) and self.count < self.maxiter + 1:
+             self.add_entry(x, grad)
+
+
+ class LBFGSStatus(Enum):
+     CONVERGED = auto()
+     MAX_ITER_REACHED = auto()
+     DIVERGED = auto()
+     # Statuses that lead to Exceptions:
+     INIT_FAILED = auto()
+     LBFGS_FAILED = auto()
+
+
+ class LBFGSException(Exception):
+     DEFAULT_MESSAGE = "LBFGS failed."
+
+     def __init__(self, message=None, status: LBFGSStatus = LBFGSStatus.LBFGS_FAILED):
+         super().__init__(message or self.DEFAULT_MESSAGE)
+         self.status = status
+
+
+ class LBFGSInitFailed(LBFGSException):
+     DEFAULT_MESSAGE = "LBFGS failed to initialise."
+
+     def __init__(self, message=None):
+         super().__init__(message or self.DEFAULT_MESSAGE, LBFGSStatus.INIT_FAILED)
+
+
+ class LBFGS:
+     """L-BFGS optimizer wrapper around scipy's implementation.
+
+     Parameters
+     ----------
+     value_grad_fn : Callable
+         function that returns tuple of (value, gradient) given input x
+     maxcor : int
+         maximum number of variable metric corrections
+     maxiter : int, optional
+         maximum number of iterations, defaults to 1000
+     ftol : float, optional
+         function tolerance for convergence, defaults to 1e-5
+     gtol : float, optional
+         gradient tolerance for convergence, defaults to 1e-8
+     maxls : int, optional
+         maximum number of line search steps, defaults to 1000
+     """
+
+     def __init__(
+         self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000
+     ) -> None:
+         self.value_grad_fn = value_grad_fn
+         self.maxcor = maxcor
+         self.maxiter = maxiter
+         self.ftol = ftol
+         self.gtol = gtol
+         self.maxls = maxls
+
+     def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
+         """minimizes objective function starting from initial position.
+
+         Parameters
+         ----------
+         x0 : array_like
+             initial position
+
+         Returns
+         -------
+         x : NDArray
+             history of positions
+         g : NDArray
+             history of gradients
+         count : int
+             number of iterations
+         status : LBFGSStatus
+             final status of optimisation
+         """
+
+         x0 = np.array(x0, dtype=np.float64)
+
+         history_manager = LBFGSHistoryManager(
+             value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter
+         )
+
+         result = minimize(
+             self.value_grad_fn,
+             x0,
+             method="L-BFGS-B",
+             jac=True,
+             callback=history_manager,
+             options={
+                 "maxcor": self.maxcor,
+                 "maxiter": self.maxiter,
+                 "ftol": self.ftol,
+                 "gtol": self.gtol,
+                 "maxls": self.maxls,
+             },
+         )
+         history = history_manager.get_history()
+
+         # warnings and suggestions for LBFGSStatus are displayed at the end
+         if result.status == 1:
+             lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
+         elif (result.status == 2) or (history.count <= 1):
+             if result.nit <= 1:
+                 lbfgs_status = LBFGSStatus.INIT_FAILED
+             elif result.fun == np.inf:
+                 lbfgs_status = LBFGSStatus.DIVERGED
+         else:
+             lbfgs_status = LBFGSStatus.CONVERGED
+
+         return history.x, history.g, history.count, lbfgs_status
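The LBFGS wrapper added above records every accepted position/gradient pair via the history-manager callback and maps scipy's result.status onto LBFGSStatus. A minimal sketch on a convex quadratic, where value_grad_fn returns a (value, gradient) tuple as the wrapper expects; the module path is an assumption, since the diff does not show the new file's location:

    import numpy as np
    from pymc_extras.inference.pathfinder.lbfgs import LBFGS, LBFGSStatus  # path assumed

    def value_grad_fn(x):
        # f(x) = 0.5 * ||x||^2, gradient = x
        return 0.5 * np.sum(x**2), x

    lbfgs = LBFGS(value_grad_fn, maxcor=5, maxiter=100)
    x_history, g_history, count, status = lbfgs.minimize(np.array([3.0, -2.0, 1.0]))

    print(status is LBFGSStatus.CONVERGED)          # True for this well-behaved problem
    print(x_history.shape, g_history.shape, count)  # one row per stored iterate, including x0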