pymc-extras 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,7 +20,7 @@ class ImportanceSamplingResult:
20
20
  samples: NDArray
21
21
  pareto_k: float | None = None
22
22
  warnings: list[str] = field(default_factory=list)
23
- method: str = "none"
23
+ method: str = "psis"
24
24
 
25
25
 
26
26
  def importance_sampling(
@@ -28,7 +28,7 @@ def importance_sampling(
28
28
  logP: NDArray,
29
29
  logQ: NDArray,
30
30
  num_draws: int,
31
- method: Literal["psis", "psir", "identity", "none"] | None,
31
+ method: Literal["psis", "psir", "identity"] | None,
32
32
  random_seed: int | None = None,
33
33
  ) -> ImportanceSamplingResult:
34
34
  """Pareto Smoothed Importance Resampling (PSIR)
@@ -44,8 +44,15 @@ def importance_sampling(
44
44
  log probability values of proposal distribution, shape (L, M)
45
45
  num_draws : int
46
46
  number of draws to return where num_draws <= samples.shape[0]
47
- method : str, optional
48
- importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
47
+ method : str, None, optional
48
+ Method to apply sampling based on log importance weights (logP - logQ).
49
+ Options are:
50
+ "psis" : Pareto Smoothed Importance Sampling (default)
51
+ Recommended for more stable results.
52
+ "psir" : Pareto Smoothed Importance Resampling
53
+ Less stable than PSIS.
54
+ "identity" : Applies log importance weights directly without resampling.
55
+ None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
49
56
  random_seed : int | None
50
57
 
51
58
  Returns
@@ -71,11 +78,11 @@ def importance_sampling(
71
78
  warnings = []
72
79
  num_paths, _, N = samples.shape
73
80
 
74
- if method == "none":
81
+ if method is None:
75
82
  warnings.append(
76
83
  "Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
77
84
  )
78
- return ImportanceSamplingResult(samples=samples, warnings=warnings)
85
+ return ImportanceSamplingResult(samples=samples, warnings=warnings, method=method)
79
86
  else:
80
87
  samples = samples.reshape(-1, N)
81
88
  logP = logP.ravel()
@@ -91,17 +98,16 @@ def importance_sampling(
91
98
  _warnings.filterwarnings(
92
99
  "ignore", category=RuntimeWarning, message="overflow encountered in exp"
93
100
  )
94
- if method == "psis":
95
- replace = False
96
- logiw, pareto_k = az.psislw(logiw)
97
- elif method == "psir":
98
- replace = True
99
- logiw, pareto_k = az.psislw(logiw)
100
- elif method == "identity":
101
- replace = False
102
- pareto_k = None
103
- else:
104
- raise ValueError(f"Invalid importance sampling method: {method}")
101
+ match method:
102
+ case "psis":
103
+ replace = False
104
+ logiw, pareto_k = az.psislw(logiw)
105
+ case "psir":
106
+ replace = True
107
+ logiw, pareto_k = az.psislw(logiw)
108
+ case "identity":
109
+ replace = False
110
+ pareto_k = None
105
111
 
106
112
  # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
107
113
  # Pareto k may not be a good diagnostic for Pathfinder.
@@ -60,6 +60,7 @@ from pytensor.graph import Apply, Op, vectorize_graph
60
60
  from pytensor.tensor import TensorConstant, TensorVariable
61
61
  from rich.console import Console, Group
62
62
  from rich.padding import Padding
63
+ from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
63
64
  from rich.table import Table
64
65
  from rich.text import Text
65
66
 
@@ -155,7 +156,7 @@ def convert_flat_trace_to_idata(
155
156
  postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
156
157
  inference_backend: Literal["pymc", "blackjax"] = "pymc",
157
158
  model: Model | None = None,
158
- importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
159
+ importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
159
160
  ) -> az.InferenceData:
160
161
  """convert flattened samples to arviz InferenceData format.
161
162
 
@@ -180,7 +181,7 @@ def convert_flat_trace_to_idata(
180
181
  arviz inference data object
181
182
  """
182
183
 
183
- if importance_sampling == "none":
184
+ if importance_sampling is None:
184
185
  # samples.ndim == 3 in this case, otherwise ndim == 2
185
186
  num_paths, num_pdraws, N = samples.shape
186
187
  samples = samples.reshape(-1, N)
@@ -219,7 +220,7 @@ def convert_flat_trace_to_idata(
219
220
  fn.trust_input = True
220
221
  result = fn(*list(trace.values()))
221
222
 
222
- if importance_sampling == "none":
223
+ if importance_sampling is None:
223
224
  result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]
224
225
 
225
226
  elif inference_backend == "blackjax":
@@ -1188,7 +1189,7 @@ class MultiPathfinderResult:
1188
1189
  elbo_argmax: NDArray | None = None
1189
1190
  lbfgs_status: Counter = field(default_factory=Counter)
1190
1191
  path_status: Counter = field(default_factory=Counter)
1191
- importance_sampling: str = "none"
1192
+ importance_sampling: str | None = "psis"
1192
1193
  warnings: list[str] = field(default_factory=list)
1193
1194
  pareto_k: float | None = None
1194
1195
 
@@ -1257,7 +1258,7 @@ class MultiPathfinderResult:
1257
1258
  def with_importance_sampling(
1258
1259
  self,
1259
1260
  num_draws: int,
1260
- method: Literal["psis", "psir", "identity", "none"] | None,
1261
+ method: Literal["psis", "psir", "identity"] | None,
1261
1262
  random_seed: int | None = None,
1262
1263
  ) -> Self:
1263
1264
  """perform importance sampling"""
@@ -1395,7 +1396,7 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
1395
1396
 
1396
1397
  path_status_message = {
1397
1398
  PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter is may be too close to the mean posterior and a poor exploration of the parameter space. Consider increasing jitter if this occurence is high relative to the number of paths.",
1398
- PathStatus.INVALID_LOGP: "Invalid logP values occur when a path's logP values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
1399
+ PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
1399
1400
  PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
1400
1401
  }
1401
1402
 
@@ -1423,7 +1424,7 @@ def multipath_pathfinder(
1423
1424
  num_elbo_draws: int,
1424
1425
  jitter: float,
1425
1426
  epsilon: float,
1426
- importance_sampling: Literal["psis", "psir", "identity", "none"] | None,
1427
+ importance_sampling: Literal["psis", "psir", "identity"] | None,
1427
1428
  progressbar: bool,
1428
1429
  concurrent: Literal["thread", "process"] | None,
1429
1430
  random_seed: RandomSeed,
@@ -1459,8 +1460,14 @@ def multipath_pathfinder(
1459
1460
  Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
1460
1461
  epsilon: float
1461
1462
  value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
1462
- importance_sampling : str, optional
1463
- importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
1463
+ importance_sampling : str, None, optional
1464
+ Method to apply sampling based on log importance weights (logP - logQ).
1465
+ "psis" : Pareto Smoothed Importance Sampling (default)
1466
+ Recommended for more stable results.
1467
+ "psir" : Pareto Smoothed Importance Resampling
1468
+ Less stable than PSIS.
1469
+ "identity" : Applies log importance weights directly without resampling.
1470
+ None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
1464
1471
  progressbar : bool, optional
1465
1472
  Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time.
1466
1473
  random_seed : RandomSeed, optional
@@ -1482,12 +1489,6 @@ def multipath_pathfinder(
1482
1489
  The result containing samples and other information from the Multi-Path Pathfinder algorithm.
1483
1490
  """
1484
1491
 
1485
- valid_importance_sampling = ["psis", "psir", "identity", "none", None]
1486
- if importance_sampling is None:
1487
- importance_sampling = "none"
1488
- if importance_sampling.lower() not in valid_importance_sampling:
1489
- raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
1490
-
1491
1492
  *path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1)
1492
1493
 
1493
1494
  pathfinder_config = PathfinderConfig(
@@ -1521,12 +1522,20 @@ def multipath_pathfinder(
1521
1522
  results = []
1522
1523
  compute_start = time.time()
1523
1524
  try:
1524
- with CustomProgress(
1525
+ desc = f"Paths Complete: {{path_idx}}/{num_paths}"
1526
+ progress = CustomProgress(
1527
+ "[progress.description]{task.description}",
1528
+ BarColumn(),
1529
+ "[progress.percentage]{task.percentage:>3.0f}%",
1530
+ TimeRemainingColumn(),
1531
+ TextColumn("/"),
1532
+ TimeElapsedColumn(),
1525
1533
  console=Console(theme=default_progress_theme),
1526
1534
  disable=not progressbar,
1527
- ) as progress:
1528
- task = progress.add_task("Fitting", total=num_paths)
1529
- for result in generator:
1535
+ )
1536
+ with progress:
1537
+ task = progress.add_task(desc.format(path_idx=0), completed=0, total=num_paths)
1538
+ for path_idx, result in enumerate(generator, start=1):
1530
1539
  try:
1531
1540
  if isinstance(result, Exception):
1532
1541
  raise result
@@ -1552,7 +1561,14 @@ def multipath_pathfinder(
1552
1561
  lbfgs_status=LBFGSStatus.LBFGS_FAILED,
1553
1562
  )
1554
1563
  )
1555
- progress.update(task, advance=1)
1564
+ finally:
1565
+ # TODO: display LBFGS and Path Status in real time
1566
+ progress.update(
1567
+ task,
1568
+ description=desc.format(path_idx=path_idx),
1569
+ completed=path_idx,
1570
+ refresh=True,
1571
+ )
1556
1572
  except (KeyboardInterrupt, StopIteration) as e:
1557
1573
  # if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
1558
1574
  if isinstance(e, StopIteration):
@@ -1606,7 +1622,7 @@ def fit_pathfinder(
1606
1622
  num_elbo_draws: int = 10, # K
1607
1623
  jitter: float = 2.0,
1608
1624
  epsilon: float = 1e-8,
1609
- importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
1625
+ importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
1610
1626
  progressbar: bool = True,
1611
1627
  concurrent: Literal["thread", "process"] | None = None,
1612
1628
  random_seed: RandomSeed | None = None,
@@ -1646,8 +1662,15 @@ def fit_pathfinder(
1646
1662
  Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
1647
1663
  epsilon: float
1648
1664
  value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
1649
- importance_sampling : str, optional
1650
- importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
1665
+ importance_sampling : str, None, optional
1666
+ Method to apply sampling based on log importance weights (logP - logQ).
1667
+ Options are:
1668
+ "psis" : Pareto Smoothed Importance Sampling (default)
1669
+ Recommended for more stable results.
1670
+ "psir" : Pareto Smoothed Importance Resampling
1671
+ Less stable than PSIS.
1672
+ "identity" : Applies log importance weights directly without resampling.
1673
+ None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
1651
1674
  progressbar : bool, optional
1652
1675
  Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
1653
1676
  random_seed : RandomSeed, optional
@@ -1674,6 +1697,15 @@ def fit_pathfinder(
1674
1697
  """
1675
1698
 
1676
1699
  model = modelcontext(model)
1700
+
1701
+ valid_importance_sampling = {"psis", "psir", "identity", None}
1702
+
1703
+ if importance_sampling is not None:
1704
+ importance_sampling = importance_sampling.lower()
1705
+
1706
+ if importance_sampling not in valid_importance_sampling:
1707
+ raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
1708
+
1677
1709
  N = DictToArrayBijection.map(model.initial_point()).data.shape[0]
1678
1710
 
1679
1711
  if maxcor is None:
@@ -28,7 +28,6 @@ from pymc_extras.statespace.filters import (
28
28
  )
29
29
  from pymc_extras.statespace.filters.distributions import (
30
30
  LinearGaussianStateSpace,
31
- MvNormalSVD,
32
31
  SequenceMvNormal,
33
32
  )
34
33
  from pymc_extras.statespace.filters.utilities import stabilize
@@ -707,7 +706,7 @@ class PyMCStateSpace:
707
706
  with pymc_model:
708
707
  for param_name in self.param_names:
709
708
  param = getattr(pymc_model, param_name, None)
710
- if param:
709
+ if param is not None:
711
710
  found_params.append(param.name)
712
711
 
713
712
  missing_params = list(set(self.param_names) - set(found_params))
@@ -746,7 +745,7 @@ class PyMCStateSpace:
746
745
  with pymc_model:
747
746
  for data_name in data_names:
748
747
  data = getattr(pymc_model, data_name, None)
749
- if data:
748
+ if data is not None:
750
749
  found_data.append(data.name)
751
750
 
752
751
  missing_data = list(set(data_names) - set(found_data))
@@ -2233,7 +2232,9 @@ class PyMCStateSpace:
2233
2232
  if shock_trajectory is None:
2234
2233
  shock_trajectory = pt.zeros((n_steps, self.k_posdef))
2235
2234
  if Q is not None:
2236
- init_shock = MvNormalSVD("initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM])
2235
+ init_shock = pm.MvNormal(
2236
+ "initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method="svd"
2237
+ )
2237
2238
  else:
2238
2239
  init_shock = pm.Deterministic(
2239
2240
  "initial_shock",
@@ -6,11 +6,9 @@ import pytensor.tensor as pt
6
6
  from pymc import intX
7
7
  from pymc.distributions.dist_math import check_parameters
8
8
  from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
9
- from pymc.distributions.multivariate import MvNormal
10
9
  from pymc.distributions.shape_utils import get_support_shape_1d
11
10
  from pymc.logprob.abstract import _logprob
12
11
  from pytensor.graph.basic import Node
13
- from pytensor.tensor.random.basic import MvNormalRV
14
12
 
15
13
  floatX = pytensor.config.floatX
16
14
  COV_ZERO_TOL = 0
@@ -49,44 +47,6 @@ def make_signature(sequence_names):
49
47
  return f"{signature},[rng]->[rng],({time},{state_and_obs})"
50
48
 
51
49
 
52
- class MvNormalSVDRV(MvNormalRV):
53
- name = "multivariate_normal"
54
- signature = "(n),(n,n)->(n)"
55
- dtype = "floatX"
56
- _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
57
-
58
-
59
- class MvNormalSVD(MvNormal):
60
- """Dummy distribution intended to be rewritten into a JAX multivariate_normal with method="svd".
61
-
62
- A JAX MvNormal robust to low-rank covariance matrices
63
- """
64
-
65
- rv_op = MvNormalSVDRV()
66
-
67
-
68
- try:
69
- import jax.random
70
-
71
- from pytensor.link.jax.dispatch.random import jax_sample_fn
72
-
73
- @jax_sample_fn.register(MvNormalSVDRV)
74
- def jax_sample_fn_mvnormal_svd(op, node):
75
- def sample_fn(rng, size, dtype, *parameters):
76
- rng_key = rng["jax_state"]
77
- rng_key, sampling_key = jax.random.split(rng_key, 2)
78
- sample = jax.random.multivariate_normal(
79
- sampling_key, *parameters, shape=size, dtype=dtype, method="svd"
80
- )
81
- rng["jax_state"] = rng_key
82
- return (rng, sample)
83
-
84
- return sample_fn
85
-
86
- except ImportError:
87
- pass
88
-
89
-
90
50
  class LinearGaussianStateSpaceRV(SymbolicRandomVariable):
91
51
  default_output = 1
92
52
  _print_name = ("LinearGuassianStateSpace", "\\operatorname{LinearGuassianStateSpace}")
@@ -244,8 +204,12 @@ class _LinearGaussianStateSpace(Continuous):
244
204
  k = T.shape[0]
245
205
  a = state[:k]
246
206
 
247
- middle_rng, a_innovation = MvNormalSVD.dist(mu=0, cov=Q, rng=rng).owner.outputs
248
- next_rng, y_innovation = MvNormalSVD.dist(mu=0, cov=H, rng=middle_rng).owner.outputs
207
+ middle_rng, a_innovation = pm.MvNormal.dist(
208
+ mu=0, cov=Q, rng=rng, method="svd"
209
+ ).owner.outputs
210
+ next_rng, y_innovation = pm.MvNormal.dist(
211
+ mu=0, cov=H, rng=middle_rng, method="svd"
212
+ ).owner.outputs
249
213
 
250
214
  a_mu = c + T @ a
251
215
  a_next = a_mu + R @ a_innovation
@@ -260,8 +224,8 @@ class _LinearGaussianStateSpace(Continuous):
260
224
  Z_init = Z_ if Z_ in non_sequences else Z_[0]
261
225
  H_init = H_ if H_ in non_sequences else H_[0]
262
226
 
263
- init_x_ = MvNormalSVD.dist(a0_, P0_, rng=rng)
264
- init_y_ = MvNormalSVD.dist(Z_init @ init_x_, H_init, rng=rng)
227
+ init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method="svd")
228
+ init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method="svd")
265
229
 
266
230
  init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
267
231
 
@@ -421,7 +385,7 @@ class SequenceMvNormal(Continuous):
421
385
  rng = pytensor.shared(np.random.default_rng())
422
386
 
423
387
  def step(mu, cov, rng):
424
- new_rng, mvn = MvNormalSVD.dist(mu=mu, cov=cov, rng=rng).owner.outputs
388
+ new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs
425
389
  return mvn, {rng: new_rng}
426
390
 
427
391
  mvn_seq, updates = pytensor.scan(
pymc_extras/version.txt CHANGED
@@ -1 +1 @@
1
- 0.2.3
1
+ 0.2.4
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: pymc-extras
3
- Version: 0.2.3
3
+ Version: 0.2.4
4
4
  Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
5
5
  Home-page: http://github.com/pymc-devs/pymc-extras
6
6
  Maintainer: PyMC Developers
@@ -20,7 +20,7 @@ Classifier: Operating System :: OS Independent
20
20
  Requires-Python: >=3.10
21
21
  Description-Content-Type: text/markdown
22
22
  License-File: LICENSE
23
- Requires-Dist: pymc>=5.20
23
+ Requires-Dist: pymc>=5.21.1
24
24
  Requires-Dist: scikit-learn
25
25
  Requires-Dist: better-optimize
26
26
  Provides-Extra: dask-histogram
@@ -3,7 +3,7 @@ pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,39
3
3
  pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
4
4
  pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
5
5
  pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
6
- pymc_extras/version.txt,sha256=OrlMBNJJhvOvKIuhzaLAu928Wonf8JcYKAX1RXjh6nU,6
6
+ pymc_extras/version.txt,sha256=FyW_plJMUmXnwXHPBlaEF9OblH__ScJC8DhZR5yCM0s,6
7
7
  pymc_extras/distributions/__init__.py,sha256=gTX7tvX8NcgP7V72URV7GeqF1aAEjGVbuW8LMxhXceY,1295
8
8
  pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
9
9
  pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
@@ -18,9 +18,9 @@ pymc_extras/inference/find_map.py,sha256=vl5l0ei48PnX-uTuHVTr-9QpCEHc8xog-KK6sOn
18
18
  pymc_extras/inference/fit.py,sha256=S9R48dh74s6K0MC9Iys4NAwVjP6rVRfx6SF-kPiR70E,1165
19
19
  pymc_extras/inference/laplace.py,sha256=uOZGp8ssQuhvCHV_Y_v3icsr4rhcYgr_qlr9dS7pcSM,21761
20
20
  pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
21
- pymc_extras/inference/pathfinder/importance_sampling.py,sha256=VvmuaE3aw_Mo3tMwswfF0rqe19mnhOCpzIScaJzjA1Y,6159
21
+ pymc_extras/inference/pathfinder/importance_sampling.py,sha256=NwxepXOFit3cA5zEebniKdlnJ1rZWg56aMlH4MEOcG4,6264
22
22
  pymc_extras/inference/pathfinder/lbfgs.py,sha256=P0UIOVtspdLzDU6alK-y91qzVAzXjYAXPuGmZ1nRqMo,5715
23
- pymc_extras/inference/pathfinder/pathfinder.py,sha256=fomZ5voVcWxvhWpeIZV7IHGIJCasT1g0ivC4dC3-0GM,63694
23
+ pymc_extras/inference/pathfinder/pathfinder.py,sha256=baw8NUN4hdylM0o4JpCh32xxig-fNFLjh_W9qsvvmM0,64495
24
24
  pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
25
25
  pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
26
26
  pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -37,9 +37,9 @@ pymc_extras/statespace/__init__.py,sha256=0MtZj7yT6jcyERvITnn-nkhyY8fO6Za4_vV53C
37
37
  pymc_extras/statespace/core/__init__.py,sha256=huHEiXAm8zV2MZyZ8GBHp6q7_fnWqveM7lC6ilpb3iE,309
38
38
  pymc_extras/statespace/core/compile.py,sha256=9FZfE8Bi3VfElxujfOIKRVvmyL9M5R0WfNEqPc5kbVQ,1603
39
39
  pymc_extras/statespace/core/representation.py,sha256=DwNIun6wdeEA20oWBx5M4govyWTf5JI87aGQ_E6Mb4U,18956
40
- pymc_extras/statespace/core/statespace.py,sha256=K_WVnWKlI6sR2kgriq9sctQVvwXCeAirm14TthDpmRM,96860
40
+ pymc_extras/statespace/core/statespace.py,sha256=Tx-821UNNLqsZgHzRmwaQ6s-agp_OthqSsbfwDpA1o0,96927
41
41
  pymc_extras/statespace/filters/__init__.py,sha256=N9Q4D0gAq_ZtT-GtrqiX1HkSg6Orv7o1TbrWUtnbTJE,420
42
- pymc_extras/statespace/filters/distributions.py,sha256=-9j__vRqL5hKyYFnQr5HKHA5kEFzwiuSccH4mslTOuQ,12900
42
+ pymc_extras/statespace/filters/distributions.py,sha256=ejimTFLgBFZMEznxY5zh6u4Vrqij60i0k2_sxdPcZ3A,11878
43
43
  pymc_extras/statespace/filters/kalman_filter.py,sha256=HELC3aK4k8EdWlUAk5_F7y7YkIz-Xi_0j2AwRgAXgcc,31949
44
44
  pymc_extras/statespace/filters/kalman_smoother.py,sha256=ypH9K_88nfJ5K2Cq737aWL3p8v4UfI7MxnYs54WPdDs,4329
45
45
  pymc_extras/statespace/filters/utilities.py,sha256=iwdaYnO1cO06t_XUjLLRmqb8vwzzVH6Nx1iyZcbJL2k,1584
@@ -66,7 +66,7 @@ tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6
66
66
  tests/test_laplace.py,sha256=u4o-0y4v1emaTMYr_rOyL_EKY_bQIz0DUXFuwuDbfNg,9314
67
67
  tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
68
68
  tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
69
- tests/test_pathfinder.py,sha256=GnSbZJ9QuFW9UVbkWaVgMVqQZTCttOyz_rSflxhQ-EA,4955
69
+ tests/test_pathfinder.py,sha256=-ekjetUSWnNRe7YausDvD00Cqh0zpBW3xn5z1hJ37MI,6027
70
70
  tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
71
71
  tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
72
72
  tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
@@ -98,8 +98,8 @@ tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJ
98
98
  tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
99
99
  tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
100
100
  tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
101
- pymc_extras-0.2.3.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
102
- pymc_extras-0.2.3.dist-info/METADATA,sha256=ZTiMM7hvVRF3O_liRu4Aea_EuxJc4vHfTD2CbRRQrcU,5152
103
- pymc_extras-0.2.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
104
- pymc_extras-0.2.3.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
105
- pymc_extras-0.2.3.dist-info/RECORD,,
101
+ pymc_extras-0.2.4.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
102
+ pymc_extras-0.2.4.dist-info/METADATA,sha256=ozmK251JzJsJLI9yx8NFhhVCgOy5nfcfSfE5IfTP3ok,5154
103
+ pymc_extras-0.2.4.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
104
+ pymc_extras-0.2.4.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
105
+ pymc_extras-0.2.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (76.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
tests/test_pathfinder.py CHANGED
@@ -44,8 +44,8 @@ def reference_idata():
44
44
  with model:
45
45
  idata = pmx.fit(
46
46
  method="pathfinder",
47
- num_paths=50,
48
- jitter=10.0,
47
+ num_paths=10,
48
+ jitter=12.0,
49
49
  random_seed=41,
50
50
  inference_backend="pymc",
51
51
  )
@@ -62,15 +62,15 @@ def test_pathfinder(inference_backend, reference_idata):
62
62
  with model:
63
63
  idata = pmx.fit(
64
64
  method="pathfinder",
65
- num_paths=50,
66
- jitter=10.0,
65
+ num_paths=10,
66
+ jitter=12.0,
67
67
  random_seed=41,
68
68
  inference_backend=inference_backend,
69
69
  )
70
70
  else:
71
71
  idata = reference_idata
72
- np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.6)
73
- np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.5)
72
+ np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=0.95)
73
+ np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.35)
74
74
 
75
75
  assert idata.posterior["mu"].shape == (1, 1000)
76
76
  assert idata.posterior["tau"].shape == (1, 1000)
@@ -83,8 +83,8 @@ def test_concurrent_results(reference_idata, concurrent):
83
83
  with model:
84
84
  idata_conc = pmx.fit(
85
85
  method="pathfinder",
86
- num_paths=50,
87
- jitter=10.0,
86
+ num_paths=10,
87
+ jitter=12.0,
88
88
  random_seed=41,
89
89
  inference_backend="pymc",
90
90
  concurrent=concurrent,
@@ -108,7 +108,7 @@ def test_seed(reference_idata):
108
108
  with model:
109
109
  idata_41 = pmx.fit(
110
110
  method="pathfinder",
111
- num_paths=50,
111
+ num_paths=4,
112
112
  jitter=10.0,
113
113
  random_seed=41,
114
114
  inference_backend="pymc",
@@ -116,7 +116,7 @@ def test_seed(reference_idata):
116
116
 
117
117
  idata_123 = pmx.fit(
118
118
  method="pathfinder",
119
- num_paths=50,
119
+ num_paths=4,
120
120
  jitter=10.0,
121
121
  random_seed=123,
122
122
  inference_backend="pymc",
@@ -171,3 +171,33 @@ def test_bfgs_sample():
171
171
  assert gamma.eval().shape == (L, 2 * J, 2 * J)
172
172
  assert phi.eval().shape == (L, num_samples, N)
173
173
  assert logq.eval().shape == (L, num_samples)
174
+
175
+
176
+ @pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])
177
+ def test_pathfinder_importance_sampling(importance_sampling):
178
+ model = eight_schools_model()
179
+
180
+ num_paths = 4
181
+ num_draws_per_path = 300
182
+ num_draws = 750
183
+
184
+ with model:
185
+ idata = pmx.fit(
186
+ method="pathfinder",
187
+ num_paths=num_paths,
188
+ num_draws_per_path=num_draws_per_path,
189
+ num_draws=num_draws,
190
+ maxiter=5,
191
+ random_seed=41,
192
+ inference_backend="pymc",
193
+ importance_sampling=importance_sampling,
194
+ )
195
+
196
+ if importance_sampling is None:
197
+ assert idata.posterior["mu"].shape == (num_paths, num_draws_per_path)
198
+ assert idata.posterior["tau"].shape == (num_paths, num_draws_per_path)
199
+ assert idata.posterior["theta"].shape == (num_paths, num_draws_per_path, 8)
200
+ else:
201
+ assert idata.posterior["mu"].shape == (1, num_draws)
202
+ assert idata.posterior["tau"].shape == (1, num_draws)
203
+ assert idata.posterior["theta"].shape == (1, num_draws, 8)