pymc-extras 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff shows the content of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
pymc_extras/__init__.py CHANGED
@@ -15,9 +15,7 @@ import logging
 
 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
-from pymc_extras.inference.find_map import find_MAP
-from pymc_extras.inference.fit import fit
-from pymc_extras.inference.laplace import fit_laplace
+from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
 from pymc_extras.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
@@ -26,6 +26,7 @@ from pymc_extras.distributions.discrete import (
 from pymc_extras.distributions.histogram_utils import histogram_approximation
 from pymc_extras.distributions.multivariate import R2D2M2CP
 from pymc_extras.distributions.timeseries import DiscreteMarkovChain
+from pymc_extras.distributions.transforms import PartialOrder
 
 __all__ = [
     "Chi",
@@ -37,4 +38,5 @@ __all__ = [
     "R2D2M2CP",
     "Skellam",
     "histogram_approximation",
+    "PartialOrder",
 ]
pymc_extras/distributions/transforms/__init__.py ADDED
@@ -0,0 +1,3 @@
+from pymc_extras.distributions.transforms.partial_order import PartialOrder
+
+__all__ = ["PartialOrder"]
pymc_extras/distributions/transforms/partial_order.py ADDED
@@ -0,0 +1,227 @@
+# Copyright 2025 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytensor.tensor as pt
+
+from pymc.logprob.transforms import Transform
+
+__all__ = ["PartialOrder"]
+
+
+def dtype_minval(dtype):
+    """Find the minimum value for a given dtype"""
+    return np.iinfo(dtype).min if np.issubdtype(dtype, np.integer) else np.finfo(dtype).min
+
+
+def padded_where(x, to_len, padval=-1):
+    """A padded version of np.where"""
+    w = np.where(x)
+    return np.concatenate([w[0], np.full(to_len - len(w[0]), padval)])
+
+
+class PartialOrder(Transform):
+    """Create a PartialOrder transform
+
+    A more flexible version of the PyMC ordered transform that
+    allows specifying a (strict) partial order on the elements.
+
+    Examples
+    --------
+    .. code:: python
+
+        import numpy as np
+        import pymc as pm
+        import pymc_extras as pmx
+
+        # Define two partial orders on 4 elements
+        # am[i, j] = 1 means i < j
+        adj_mats = np.array([
+            # 0 < {1, 2} < 3
+            [[0, 1, 1, 0],
+             [0, 0, 0, 1],
+             [0, 0, 0, 1],
+             [0, 0, 0, 0]],
+
+            # 1 < 0 < 3 < 2
+            [[0, 0, 0, 1],
+             [1, 0, 0, 0],
+             [0, 0, 0, 0],
+             [0, 0, 1, 0]],
+        ])
+
+        # Create the partial order from the adjacency matrices
+        po = pmx.PartialOrder(adj_mats)
+
+        with pm.Model() as model:
+            # Generate 3 samples from both partial orders
+            pm.Normal("po_vals", shape=(3, 2, 4), transform=po,
+                      initval=po.initvals((3, 2, 4)))
+
+            idata = pm.sample()
+
+        # Verify that for the first po, the zeroth element is always the smallest
+        assert (idata.posterior['po_vals'][:, :, :, 0, 0] <
+                idata.posterior['po_vals'][:, :, :, 0, 1:]).all()
+
+        # Verify that for the second po, the second element is always the largest
+        assert (idata.posterior['po_vals'][:, :, :, 1, 2] >=
+                idata.posterior['po_vals'][:, :, :, 1, :]).all()
+
+    Technical notes
+    ---------------
+    The partial order must be strict, i.e. without equalities.
+    A DAG defining the partial order is sufficient, as the transitive closure is computed automatically.
+    The transform runs in O(N*D) time, but initialization takes O(N^3),
+    where N is the number of nodes in the DAG and D is the maximum
+    in-degree of a node in the transitive reduction.
+    """
+
+    name = "partial_order"
+
+    def __init__(self, adj_mat):
+        """
+        Initialize the PartialOrder transform
+
+        Parameters
+        ----------
+        adj_mat: ndarray
+            adjacency matrix for the DAG that generates the partial order,
+            where ``adj_mat[i][j] = 1`` denotes ``i < j``.
+            Note that this also accepts multiple DAGs if the RV is multidimensional.
+        """
+
+        # Basic input checks
+        if adj_mat.ndim < 2:
+            raise ValueError("Adjacency matrix must have at least 2 dimensions")
+        if adj_mat.shape[-2] != adj_mat.shape[-1]:
+            raise ValueError("Adjacency matrix is not square")
+        if adj_mat.min() != 0 or adj_mat.max() != 1:
+            raise ValueError("Adjacency matrix must contain only 0s and 1s")
+
+        # Create index over the first ellipsis dimensions
+        idx = np.ix_(*[np.arange(s) for s in adj_mat.shape[:-2]])
+
+        # Transitive closure using Floyd-Warshall
+        tc = adj_mat.astype(bool)
+        for k in range(tc.shape[-1]):
+            tc |= np.logical_and(tc[..., :, k, None], tc[..., None, k, :])
+
+        # Check that the DAG is acyclic (a cycle would imply an equality)
+        if np.any(tc.diagonal(axis1=-2, axis2=-1)):
+            raise ValueError("Partial order contains equalities")
+
+        # Transitive reduction using the closure
+        # This gives the minimum description of the partial order
+        # and keeps the maximum in-degree small
+        adj_mat = tc * (1 - np.matmul(tc, tc))
+
+        # Find the maximum in-degree of the reduced DAG
+        dag_idim = adj_mat.sum(axis=-2).max()
+
+        # Topological sort
+        ts_inds = np.zeros(adj_mat.shape[:-1], dtype=int)
+        dm = adj_mat.copy()
+        for i in range(adj_mat.shape[1]):
+            assert dm.sum(axis=-2).min() == 0  # DAG is acyclic
+            nind = np.argmin(dm.sum(axis=-2), axis=-1)
+            dm[(*idx, slice(None), nind)] = 1  # Make nind not show up again
+            dm[(*idx, nind, slice(None))] = 0  # Allow its children to show
+            ts_inds[(*idx, i)] = nind
+        self.ts_inds = ts_inds
+
+        # Change the DAG to adjacency lists (with -1 for NA)
+        dag_T = np.apply_along_axis(padded_where, axis=-2, arr=adj_mat, padval=-1, to_len=dag_idim)
+        self.dag = np.swapaxes(dag_T, -2, -1)
+        self.is_start = np.all(self.dag[..., :, :] == -1, axis=-1)
+
+    def initvals(self, shape=None, lower=-1, upper=1):
+        """
+        Create a set of appropriate initial values for the variable.
+        NB! It is important that proper initial values are used,
+        as only properly ordered values are in the range of the transform.
+
+        Parameters
+        ----------
+        shape: tuple, default None
+            shape of the initial values. If None, ``adj_mat.shape[:-1]`` is used
+        lower: float, default -1
+            lower bound for the initial values
+        upper: float, default 1
+            upper bound for the initial values
+
+        Returns
+        -------
+        vals: ndarray
+            initial values for the transformed variable
+        """
+
+        if shape is None:
+            shape = self.dag.shape[:-1]
+
+        if shape[-len(self.dag.shape[:-1]) :] != self.dag.shape[:-1]:
+            raise ValueError("Shape must match the shape of the adjacency matrix")
+
+        # Create the initial values
+        vals = np.linspace(lower, upper, self.dag.shape[-2])
+        inds = np.argsort(self.ts_inds, axis=-1)
+        ivals = vals[inds]
+
+        # Expand the initial values to the extra dimensions
+        extra_dims = shape[: -len(self.dag.shape[:-1])]
+        ivals = np.tile(ivals, extra_dims + tuple([1] * len(self.dag.shape[:-1])))
+
+        return ivals
+
+    def backward(self, value, *inputs):
+        minv = dtype_minval(value.dtype)
+        x = pt.concatenate(
+            [pt.zeros_like(value), pt.full(value.shape[:-1], minv)[..., None]], axis=-1
+        )
+
+        # Indices to allow broadcasting the max over the last dimension
+        idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
+        idx2 = tuple(np.tile(i[:, None], self.dag.shape[-1]) for i in idx)
+
+        # Has to be done stepwise, as later steps depend on previous values
+        # Also has to be done in topological order, hence the ts_inds
+        for i in range(self.dag.shape[-2]):
+            tsi = self.ts_inds[..., i]
+            if len(tsi.shape) == 0:
+                tsi = int(tsi)  # if 0-d, convert to a scalar index
+            ni = (*idx, tsi)  # i-th node in topological order
+            eni = (Ellipsis, *ni)
+            ist = self.is_start[ni]
+
+            mval = pt.max(x[(Ellipsis, *idx2, self.dag[ni])], axis=-1)
+            x = pt.set_subtensor(x[eni], ist * value[eni] + (1 - ist) * (mval + pt.exp(value[eni])))
+        return x[..., :-1]
+
+    def forward(self, value, *inputs):
+        y = pt.zeros_like(value)
+
+        minv = dtype_minval(value.dtype)
+        vx = pt.concatenate([value, pt.full(value.shape[:-1], minv)[..., None]], axis=-1)
+
+        # Indices to allow broadcasting the max over the last dimension
+        idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
+        idx = tuple(np.tile(i[:, None, None], self.dag.shape[-2:]) for i in idx)
+
+        y = self.is_start * value + (1 - self.is_start) * (
+            pt.log(value - pt.max(vx[(Ellipsis, *idx, self.dag[..., :])], axis=-1))
+        )
+
+        return y
+
+    def log_jac_det(self, value, *inputs):
+        return pt.sum(value * (1 - self.is_start), axis=-1)
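
The transitive closure / reduction pass above is the O(N^3) part of initialization. As a standalone illustration (a sketch of mine, not package code), the same computation on the first DAG from the docstring:

    # Replay PartialOrder.__init__'s closure/reduction steps on "0 < {1, 2} < 3".
    import numpy as np

    adj = np.array([
        [0, 1, 1, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 0],
    ])

    # Floyd-Warshall transitive closure: adds the implied edge 0 < 3
    tc = adj.astype(bool)
    for k in range(tc.shape[-1]):
        tc |= np.logical_and(tc[:, k, None], tc[None, k, :])
    assert tc[0, 3] and not tc.diagonal().any()  # acyclic, hence a strict order

    # Transitive reduction: drop edge i -> j when a longer path i -> ... -> j exists
    red = tc.astype(int) * (1 - (tc.astype(int) @ tc.astype(int) > 0))
    assert (red == adj).all()  # this input was already reduced

    # D = maximum in-degree of the reduced DAG; bounds the max() in backward()
    assert red.sum(axis=0).max() == 2  # node 3 has two parents, 1 and 2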
pymc_extras/inference/__init__.py CHANGED
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.laplace import fit_laplace
+from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
 
-__all__ = ["fit"]
+__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
pymc_extras/inference/fit.py CHANGED
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import arviz as az
 
 
-def fit(method, **kwargs):
+def fit(method: str, **kwargs) -> az.InferenceData:
     """
-    Fit a model with an inference algorithm
+    Fit a model with an inference algorithm.
+    See :func:`fit_pathfinder` and :func:`fit_laplace` for more details.
 
     Parameters
     ----------
@@ -23,11 +25,11 @@ def fit(method, **kwargs):
         Which inference method to run.
         Supported: pathfinder or laplace
 
-    kwargs are passed on.
+    kwargs: keyword arguments are passed on to the inference method.
 
     Returns
     -------
-    arviz.InferenceData
+    :class:`~arviz.InferenceData`
     """
     if method == "pathfinder":
         from pymc_extras.inference.pathfinder import fit_pathfinder
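
A minimal usage sketch for the annotated fit() (the toy model is mine; extra keyword arguments are forwarded to the chosen method, as the docstring describes):

    import pymc as pm
    import pymc_extras as pmx

    with pm.Model():
        mu = pm.Normal("mu", 0, 1)
        pm.Normal("y", mu=mu, sigma=1, observed=[0.1, -0.3, 0.2])
        idata = pmx.fit(method="pathfinder")  # dispatches to fit_pathfinder
        # idata = pmx.fit(method="laplace")   # dispatches to fit_laplace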
pymc_extras/inference/laplace.py CHANGED
@@ -377,7 +377,10 @@ def sample_laplace_posterior(
     posterior_dist = stats.multivariate_normal(
         mean=mu.data, cov=H_inv, allow_singular=True, seed=rng
     )
+
     posterior_draws = posterior_dist.rvs(size=(chains, draws))
+    if mu.data.shape == (1,):
+        posterior_draws = np.expand_dims(posterior_draws, -1)
 
     if transform_samples:
         constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)
@@ -506,7 +509,7 @@ fit_laplace(
 
     Returns
     -------
-    idata: az.InferenceData
+    :class:`~arviz.InferenceData`
         An InferenceData object containing the approximated posterior samples.
 
     Examples
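
The expand_dims guard above works around a scipy behavior: for a one-dimensional multivariate normal, rvs() squeezes out the event dimension, so single-parameter models would lose the trailing axis. A standalone sketch with toy shapes (not package code):

    import numpy as np
    from scipy import stats

    chains, draws = 4, 100
    dist = stats.multivariate_normal(mean=np.zeros(1), cov=np.eye(1))
    samples = dist.rvs(size=(chains, draws))  # shape (4, 100): event dim dropped
    samples = np.expand_dims(samples, -1)     # restored to (chains, draws, 1)
    assert samples.shape == (chains, draws, 1)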
pymc_extras/inference/pathfinder/importance_sampling.py CHANGED
@@ -20,7 +20,7 @@ class ImportanceSamplingResult:
     samples: NDArray
     pareto_k: float | None = None
     warnings: list[str] = field(default_factory=list)
-    method: str = "none"
+    method: str = "psis"
 
 
 def importance_sampling(
@@ -28,7 +28,7 @@ def importance_sampling(
     logP: NDArray,
     logQ: NDArray,
     num_draws: int,
-    method: Literal["psis", "psir", "identity", "none"] | None,
+    method: Literal["psis", "psir", "identity"] | None,
     random_seed: int | None = None,
 ) -> ImportanceSamplingResult:
     """Pareto Smoothed Importance Resampling (PSIR)
@@ -44,8 +44,15 @@
         log probability values of proposal distribution, shape (L, M)
     num_draws : int
         number of draws to return where num_draws <= samples.shape[0]
-    method : str, optional
-        importance sampling method to use. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
+    method : str, None, optional
+        Method to apply sampling based on log importance weights (logP - logQ).
+        Options are:
+        "psis" : Pareto Smoothed Importance Sampling (default)
+            Recommended for more stable results.
+        "psir" : Pareto Smoothed Importance Resampling
+            Less stable than PSIS.
+        "identity" : Applies log importance weights directly without resampling.
+        None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N), where N is the number of model parameters. Other methods return samples of size (num_draws, N).
     random_seed : int | None
 
     Returns
@@ -71,11 +78,11 @@
     warnings = []
     num_paths, _, N = samples.shape
 
-    if method == "none":
+    if method is None:
         warnings.append(
             "Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
         )
-        return ImportanceSamplingResult(samples=samples, warnings=warnings)
+        return ImportanceSamplingResult(samples=samples, warnings=warnings, method=method)
     else:
         samples = samples.reshape(-1, N)
         logP = logP.ravel()
@@ -91,17 +98,16 @@
         _warnings.filterwarnings(
             "ignore", category=RuntimeWarning, message="overflow encountered in exp"
         )
-        if method == "psis":
-            replace = False
-            logiw, pareto_k = az.psislw(logiw)
-        elif method == "psir":
-            replace = True
-            logiw, pareto_k = az.psislw(logiw)
-        elif method == "identity":
-            replace = False
-            pareto_k = None
-        else:
-            raise ValueError(f"Invalid importance sampling method: {method}")
+        match method:
+            case "psis":
+                replace = False
+                logiw, pareto_k = az.psislw(logiw)
+            case "psir":
+                replace = True
+                logiw, pareto_k = az.psislw(logiw)
+            case "identity":
+                replace = False
+                pareto_k = None
 
     # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
     # Pareto k may not be a good diagnostic for Pathfinder.
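
For orientation, a hedged sketch of what the retained modes do with the log importance weights; the logP/logQ arrays here are toy values, and az.psislw is the same ArviZ routine the module calls:

    import arviz as az
    import numpy as np

    rng = np.random.default_rng(0)
    logP = rng.normal(size=1000)  # toy target log-densities
    logQ = rng.normal(size=1000)  # toy proposal log-densities
    logiw = logP - logQ           # raw log importance weights

    # "psis"/"psir": smooth the weight tail; pareto_k diagnoses reliability
    logiw_smoothed, pareto_k = az.psislw(logiw)

    # resampling step: without replacement for "psis", with replacement for "psir"
    p = np.exp(logiw_smoothed - np.max(logiw_smoothed))
    draws = rng.choice(1000, size=500, replace=False, p=p / p.sum())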
pymc_extras/inference/pathfinder/lbfgs.py CHANGED
@@ -37,11 +37,14 @@ class LBFGSHistoryManager:
         initial position
     maxiter : int
         maximum number of iterations to store
+    epsilon : float
+        tolerance for the LBFGS update
     """
 
     value_grad_fn: Callable[[NDArray[np.float64]], tuple[np.float64, NDArray[np.float64]]]
     x0: NDArray[np.float64]
     maxiter: int
+    epsilon: float
     x_history: NDArray[np.float64] = field(init=False)
     g_history: NDArray[np.float64] = field(init=False)
     count: int = field(init=False)
@@ -52,7 +55,7 @@ class LBFGSHistoryManager:
         self.count = 0
 
         value, grad = self.value_grad_fn(self.x0)
-        if np.all(np.isfinite(grad)) and np.isfinite(value):
+        if self.entry_condition_met(self.x0, value, grad):
             self.add_entry(self.x0, grad)
 
     def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None:
@@ -75,18 +78,39 @@ class LBFGSHistoryManager:
             x=self.x_history[: self.count], g=self.g_history[: self.count], count=self.count
         )
 
+    def entry_condition_met(self, x, value, grad) -> bool:
+        """Checks whether the LBFGS iterate should be added to the history."""
+
+        if np.all(np.isfinite(grad)) and np.isfinite(value) and (self.count < self.maxiter + 1):
+            if self.count == 0:
+                return True
+            else:
+                s = x - self.x_history[self.count - 1]
+                z = grad - self.g_history[self.count - 1]
+                sz = (s * z).sum(axis=-1)
+                update = sz > self.epsilon * np.sqrt(np.sum(z**2, axis=-1))
+
+                if update:
+                    return True
+                else:
+                    return False
+        else:
+            return False
+
     def __call__(self, x: NDArray[np.float64]) -> None:
         value, grad = self.value_grad_fn(x)
-        if np.all(np.isfinite(grad)) and np.isfinite(value) and self.count < self.maxiter + 1:
+        if self.entry_condition_met(x, value, grad):
            self.add_entry(x, grad)
 
 
 class LBFGSStatus(Enum):
     CONVERGED = auto()
     MAX_ITER_REACHED = auto()
-    DIVERGED = auto()
+    NON_FINITE = auto()
+    LOW_UPDATE_PCT = auto()
     # Statuses that lead to Exceptions:
     INIT_FAILED = auto()
+    INIT_FAILED_LOW_UPDATE_PCT = auto()
     LBFGS_FAILED = auto()
 
 
@@ -101,8 +125,8 @@ class LBFGSException(Exception):
 class LBFGSInitFailed(LBFGSException):
     DEFAULT_MESSAGE = "LBFGS failed to initialise."
 
-    def __init__(self, message=None):
-        super().__init__(message or self.DEFAULT_MESSAGE, LBFGSStatus.INIT_FAILED)
+    def __init__(self, status: LBFGSStatus, message=None):
+        super().__init__(message or self.DEFAULT_MESSAGE, status)
 
 
 class LBFGS:
@@ -122,10 +146,12 @@ class LBFGS:
         gradient tolerance for convergence, defaults to 1e-8
     maxls : int, optional
         maximum number of line search steps, defaults to 1000
+    epsilon : float, optional
+        tolerance for the LBFGS update, defaults to 1e-8
     """
 
     def __init__(
-        self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000
+        self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000, epsilon=1e-8
     ) -> None:
         self.value_grad_fn = value_grad_fn
         self.maxcor = maxcor
@@ -133,6 +159,7 @@ class LBFGS:
         self.ftol = ftol
         self.gtol = gtol
         self.maxls = maxls
+        self.epsilon = epsilon
 
     def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
         """minimizes objective function starting from initial position.
@@ -157,7 +184,7 @@ class LBFGS:
         x0 = np.array(x0, dtype=np.float64)
 
         history_manager = LBFGSHistoryManager(
-            value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter
+            value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter, epsilon=self.epsilon
         )
 
         result = minimize(
@@ -177,13 +204,22 @@
         history = history_manager.get_history()
 
         # warnings and suggestions for LBFGSStatus are displayed at the end
-        if result.status == 1:
-            lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
-        elif (result.status == 2) or (history.count <= 1):
-            if result.nit <= 1:
+        # threshold determining whether the number of LBFGS updates is low compared to the number of iterations
+        low_update_threshold = 3
+
+        if history.count <= 1:  # triggers LBFGSInitFailed
+            if result.nit < low_update_threshold:
                 lbfgs_status = LBFGSStatus.INIT_FAILED
-            elif result.fun == np.inf:
-                lbfgs_status = LBFGSStatus.DIVERGED
+            else:
+                lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT
+        elif result.status == 1:
+            # (result.nit > maxiter) or (result.nit > maxls)
+            lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
+        elif result.status == 2:
+            # precision loss resulting in inf or nan
+            lbfgs_status = LBFGSStatus.NON_FINITE
+        elif history.count * low_update_threshold < result.nit:
+            lbfgs_status = LBFGSStatus.LOW_UPDATE_PCT
         else:
            lbfgs_status = LBFGSStatus.CONVERGED
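
The new entry_condition_met gate is a curvature check: an iterate is stored only when s.z > epsilon * ||z||, where s is the position difference and z the gradient difference, so pairs carrying too little curvature information for the inverse-Hessian estimate are skipped (LOW_UPDATE_PCT then flags runs where most iterates were rejected). A toy numeric check with made-up values:

    import numpy as np

    epsilon = 1e-8
    x_prev, g_prev = np.array([0.0, 0.0]), np.array([-1.0, -1.0])
    x_new, g_new = np.array([0.5, 0.5]), np.array([-0.5, -0.5])

    s = x_new - x_prev          # position difference
    z = g_new - g_prev          # gradient difference
    sz = (s * z).sum(axis=-1)   # curvature proxy: s.z = 0.5
    accept = sz > epsilon * np.sqrt(np.sum(z**2, axis=-1))
    assert accept  # 0.5 > 1e-8 * 0.707..., so this entry would be stored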
225