pymc-extras 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff reflects the changes between publicly released versions of the package, as they appear in the public registries to which they were published, and is provided for informational purposes only.
@@ -12,22 +12,19 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+
  import collections
  import logging
  import time
- import warnings as _warnings

  from collections import Counter
  from collections.abc import Callable, Iterator
  from dataclasses import asdict, dataclass, field, replace
  from enum import Enum, auto
- from importlib.util import find_spec
  from typing import Literal, TypeAlias

  import arviz as az
- import blackjax
  import filelock
- import jax
  import numpy as np
  import pymc as pm
  import pytensor
@@ -42,11 +39,10 @@ from pymc.initial_point import make_initial_point_fn
  from pymc.model import modelcontext
  from pymc.model.core import Point
  from pymc.pytensorf import (
- compile_pymc,
+ compile,
  find_rng_nodes,
  reseed_rngs,
  )
- from pymc.sampling.jax import get_jaxified_graph
  from pymc.util import (
  CustomProgress,
  RandomSeed,
@@ -60,12 +56,14 @@ from pytensor.graph import Apply, Op, vectorize_graph
  from pytensor.tensor import TensorConstant, TensorVariable
  from rich.console import Console, Group
  from rich.padding import Padding
+ from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
  from rich.table import Table
  from rich.text import Text

  # TODO: change to typing.Self after Python versions greater than 3.10
  from typing_extensions import Self

+ from pymc_extras.inference.laplace import add_data_to_inferencedata
  from pymc_extras.inference.pathfinder.importance_sampling import (
  importance_sampling as _importance_sampling,
  )
@@ -77,9 +75,6 @@ from pymc_extras.inference.pathfinder.lbfgs import (
  )

  logger = logging.getLogger(__name__)
- _warnings.filterwarnings(
- "ignore", category=FutureWarning, message="compile_pymc was renamed to compile"
- )

  REGULARISATION_TERM = 1e-8
  DEFAULT_LINKER = "cvm_nogc"
@@ -104,6 +99,8 @@ def get_jaxified_logp_of_ravel_inputs(model: Model, jacobian: bool = True) -> Ca
  A JAX function that computes the log-probability of a PyMC model with ravelled inputs.
  """

+ from pymc.sampling.jax import get_jaxified_graph
+
  # TODO: JAX: test if we should get jaxified graph of dlogp as well
  new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
  model.initial_point(), (model.logp(jacobian=jacobian),), model.value_vars, ()
@@ -143,7 +140,7 @@ def get_logp_dlogp_of_ravel_inputs(
  [model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)],
  model.value_vars,
  )
- logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs)
+ logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs)
  logp_dlogp_fn.trust_input = True

  return logp_dlogp_fn
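The compile_pymc → compile change above tracks a rename in PyMC itself (note the dropped FutureWarning filter for "compile_pymc was renamed to compile"). For downstream code that has to run against both older and newer PyMC releases, a minimal compatibility sketch is shown below; the alias name compile_fn is illustrative only and not part of either library.

    # Hedged sketch: pick whichever compiler helper the installed PyMC exposes.
    try:
        from pymc.pytensorf import compile as compile_fn  # newer PyMC, as used by pymc-extras 0.2.5
    except ImportError:  # older PyMC releases only ship the previous name
        from pymc.pytensorf import compile_pymc as compile_fn

    # compile_fn(inputs, outputs, **compile_kwargs) is then called the same way in both cases.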
@@ -155,7 +152,7 @@ def convert_flat_trace_to_idata(
  postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
  inference_backend: Literal["pymc", "blackjax"] = "pymc",
  model: Model | None = None,
- importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
+ importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
  ) -> az.InferenceData:
  """convert flattened samples to arviz InferenceData format.

@@ -180,7 +177,7 @@
  arviz inference data object
  """

- if importance_sampling == "none":
+ if importance_sampling is None:
  # samples.ndim == 3 in this case, otherwise ndim == 2
  num_paths, num_pdraws, N = samples.shape
  samples = samples.reshape(-1, N)
@@ -219,10 +216,14 @@
  fn.trust_input = True
  result = fn(*list(trace.values()))

- if importance_sampling == "none":
+ if importance_sampling is None:
  result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]

  elif inference_backend == "blackjax":
+ import jax
+
+ from pymc.sampling.jax import get_jaxified_graph
+
  jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
  result = jax.vmap(jax.vmap(jax_fn))(
  *jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
@@ -236,8 +237,8 @@


  def alpha_recover(
- x: TensorVariable, g: TensorVariable, epsilon: TensorVariable
- ) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
+ x: TensorVariable, g: TensorVariable
+ ) -> tuple[TensorVariable, TensorVariable, TensorVariable]:
  """compute the diagonal elements of the inverse Hessian at each iterations of L-BFGS and filter updates.

  Parameters
@@ -246,9 +247,6 @@
  position array, shape (L+1, N)
  g : TensorVariable
  gradient array, shape (L+1, N)
- epsilon : float
- threshold for filtering updates based on inner product of position
- and gradient differences

  Returns
  -------
@@ -258,15 +256,13 @@
  position differences, shape (L, N)
  z : TensorVariable
  gradient differences, shape (L, N)
- update_mask : TensorVariable
- mask for filtering updates, shape (L,)

  Notes
  -----
  shapes: L=batch_size, N=num_params
  """

- def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
+ def compute_alpha_l(s_l, z_l, alpha_lm1) -> TensorVariable:
  # alpha_lm1: (N,)
  # s_l: (N,)
  # z_l: (N,)
@@ -280,43 +276,28 @@ def alpha_recover(
  ) # fmt:off
  return 1.0 / inv_alpha_l

- def return_alpha_lm1(alpha_lm1, s_l, z_l) -> TensorVariable:
- return alpha_lm1[-1]
-
- def scan_body(update_mask_l, s_l, z_l, alpha_lm1) -> TensorVariable:
- return pt.switch(
- update_mask_l,
- compute_alpha_l(alpha_lm1, s_l, z_l),
- return_alpha_lm1(alpha_lm1, s_l, z_l),
- )
-
  Lp1, N = x.shape
  s = pt.diff(x, axis=0)
  z = pt.diff(g, axis=0)
  alpha_l_init = pt.ones(N)
- sz = (s * z).sum(axis=-1)
- # update_mask = sz > epsilon * pt.linalg.norm(z, axis=-1)
- # pt.linalg.norm does not work with JAX!!
- update_mask = sz > epsilon * pt.sqrt(pt.sum(z**2, axis=-1))

  alpha, _ = pytensor.scan(
- fn=scan_body,
+ fn=compute_alpha_l,
  outputs_info=alpha_l_init,
- sequences=[update_mask, s, z],
+ sequences=[s, z],
  n_steps=Lp1 - 1,
  allow_gc=False,
  )

  # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
- # alpha: (L, N), update_mask: (L, N)
- return alpha, s, z, update_mask
+ # alpha: (L, N)
+ return alpha, s, z


  def inverse_hessian_factors(
  alpha: TensorVariable,
  s: TensorVariable,
  z: TensorVariable,
- update_mask: TensorVariable,
  J: TensorConstant,
  ) -> tuple[TensorVariable, TensorVariable]:
  """compute the inverse hessian factors for the BFGS approximation.
@@ -329,8 +310,6 @@ def inverse_hessian_factors(
  position differences, shape (L, N)
  z : TensorVariable
  gradient differences, shape (L, N)
- update_mask : TensorVariable
- mask for filtering updates, shape (L,)
  J : TensorConstant
  history size for L-BFGS

@@ -349,30 +328,19 @@ def inverse_hessian_factors(
  # NOTE: get_chi_matrix_1 is a modified version of get_chi_matrix_2 to closely follow Zhang et al., (2022)
  # NOTE: get_chi_matrix_2 is from blackjax which MAYBE incorrectly implemented

- def get_chi_matrix_1(
- diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
- ) -> TensorVariable:
+ def get_chi_matrix_1(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
  L, N = diff.shape
  j_last = pt.as_tensor(J - 1) # since indexing starts at 0

- def chi_update(chi_lm1, diff_l) -> TensorVariable:
+ def chi_update(diff_l, chi_lm1) -> TensorVariable:
  chi_l = pt.roll(chi_lm1, -1, axis=0)
  return pt.set_subtensor(chi_l[j_last], diff_l)

- def no_op(chi_lm1, diff_l) -> TensorVariable:
- return chi_lm1
-
- def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable:
- return pt.switch(update_mask_l, chi_update(chi_lm1, diff_l), no_op(chi_lm1, diff_l))
-
  chi_init = pt.zeros((J, N))
  chi_mat, _ = pytensor.scan(
- fn=scan_body,
+ fn=chi_update,
  outputs_info=chi_init,
- sequences=[
- update_mask,
- diff,
- ],
+ sequences=[diff],
  allow_gc=False,
  )

@@ -381,19 +349,15 @@ def inverse_hessian_factors(
  # (L, N, J)
  return chi_mat

- def get_chi_matrix_2(
- diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
- ) -> TensorVariable:
+ def get_chi_matrix_2(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
  L, N = diff.shape

- diff_masked = update_mask[:, None] * diff
-
  # diff_padded: (L+J, N)
  pad_width = pt.zeros(shape=(2, 2), dtype="int32")
- pad_width = pt.set_subtensor(pad_width[0, 0], J)
- diff_padded = pt.pad(diff_masked, pad_width, mode="constant")
+ pad_width = pt.set_subtensor(pad_width[0, 0], J - 1)
+ diff_padded = pt.pad(diff, pad_width, mode="constant")

- index = pt.arange(L)[:, None] + pt.arange(J)[None, :]
+ index = pt.arange(L)[..., None] + pt.arange(J)[None, ...]
  index = index.reshape((L, J))

  chi_mat = pt.matrix_transpose(diff_padded[index])
@@ -402,8 +366,10 @@ def inverse_hessian_factors(
  return chi_mat

  L, N = alpha.shape
- S = get_chi_matrix_1(s, update_mask, J)
- Z = get_chi_matrix_1(z, update_mask, J)
+
+ # changed to get_chi_matrix_2 after removing update_mask
+ S = get_chi_matrix_2(s, J)
+ Z = get_chi_matrix_2(z, J)

  # E: (L, J, J)
  Ij = pt.eye(J)[None, ...]
@@ -488,6 +454,7 @@ def bfgs_sample_dense(

  N = x.shape[-1]
  IdN = pt.eye(N)[None, ...]
+ IdN += IdN * REGULARISATION_TERM

  # inverse Hessian
  H_inv = (
@@ -503,7 +470,10 @@

  logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)

- mu = x - pt.batched_dot(H_inv, g)
+ # mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g
+
+ batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
+ mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))

  phi = pt.matrix_transpose(
  # (L, N, 1)
@@ -564,23 +534,28 @@ def bfgs_sample_sparse(
  # qr_input: (L, N, 2J)
  qr_input = inv_sqrt_alpha_diag @ beta
  (Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False)
+
  IdN = pt.eye(R.shape[1])[None, ...]
+ IdN += IdN * REGULARISATION_TERM
+
  Lchol_input = IdN + R @ gamma @ pt.matrix_transpose(R)

+ # TODO: make robust Lchol calcs more robust, ie. try exceptions, increase REGULARISATION_TERM if non-finite exists
  Lchol = pt.linalg.cholesky(Lchol_input, lower=False, check_finite=False, on_error="nan")

  logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
  logdet += pt.sum(pt.log(alpha), axis=-1)

+ # inverse Hessian
+ # (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
+ H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta))
+
  # NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version.
- mu = x - (
- # (L, N), (L, N) -> (L, N)
- pt.batched_dot(alpha_diag, g)
- # beta @ gamma @ beta.T
- # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
- # (L, N, N), (L, N) -> (L, N)
- + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
- )
+
+ # mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g
+
+ batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
+ mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))

  phi = pt.matrix_transpose(
  # (L, N, 1)
@@ -588,8 +563,6 @@
  # (L, N, N), (L, N, M) -> (L, N, M)
  + sqrt_alpha_diag
  @ (
- # (L, N, 2J), (L, 2J, M) -> (L, N, M)
- # intermediate calcs below
  # (L, N, 2J), (L, 2J, 2J) -> (L, N, 2J)
  (Q @ (Lchol - IdN))
  # (L, 2J, N), (L, N, M) -> (L, 2J, M)
@@ -777,7 +750,6 @@ def make_pathfinder_body(
  num_draws: int,
  maxcor: int,
  num_elbo_draws: int,
- epsilon: float,
  **compile_kwargs: dict,
  ) -> Function:
  """
@@ -793,8 +765,6 @@
  The maximum number of iterations for the L-BFGS algorithm.
  num_elbo_draws : int
  The number of draws for the Evidence Lower Bound (ELBO) estimation.
- epsilon : float
- The value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L.
  compile_kwargs : dict
  Additional keyword arguments for the PyTensor compiler.

@@ -819,11 +789,10 @@

  num_draws = pt.constant(num_draws, "num_draws", dtype="int32")
  num_elbo_draws = pt.constant(num_elbo_draws, "num_elbo_draws", dtype="int32")
- epsilon = pt.constant(epsilon, "epsilon", dtype="float64")
  maxcor = pt.constant(maxcor, "maxcor", dtype="int32")

- alpha, s, z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon)
- beta, gamma = inverse_hessian_factors(alpha, s, z, update_mask, J=maxcor)
+ alpha, s, z = alpha_recover(x_full, g_full)
+ beta, gamma = inverse_hessian_factors(alpha, s, z, J=maxcor)

  # ignore initial point - x, g: (L, N)
  x = x_full[1:]
@@ -854,7 +823,7 @@

  # return psi, logP_psi, logQ_psi, elbo_argmax

- pathfinder_body_fn = compile_pymc(
+ pathfinder_body_fn = compile(
  [x_full, g_full],
  [psi, logP_psi, logQ_psi, elbo_argmax],
  **compile_kwargs,
@@ -933,11 +902,11 @@ def make_single_pathfinder_fn(
  x_base = DictToArrayBijection.map(ip).data

  # lbfgs
- lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls)
+ lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls, epsilon)

  # pathfinder body
  pathfinder_body_fn = make_pathfinder_body(
- logp_func, num_draws, maxcor, num_elbo_draws, epsilon, **compile_kwargs
+ logp_func, num_draws, maxcor, num_elbo_draws, **compile_kwargs
  )
  rngs = find_rng_nodes(pathfinder_body_fn.maker.fgraph.outputs)

@@ -949,8 +918,8 @@
  x0 = x_base + jitter_value
  x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0)

- if lbfgs_status == LBFGSStatus.INIT_FAILED:
- raise LBFGSInitFailed()
+ if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT}:
+ raise LBFGSInitFailed(lbfgs_status)
  elif lbfgs_status == LBFGSStatus.LBFGS_FAILED:
  raise LBFGSException()

@@ -1188,7 +1157,7 @@ class MultiPathfinderResult:
  elbo_argmax: NDArray | None = None
  lbfgs_status: Counter = field(default_factory=Counter)
  path_status: Counter = field(default_factory=Counter)
- importance_sampling: str = "none"
+ importance_sampling: str | None = "psis"
  warnings: list[str] = field(default_factory=list)
  pareto_k: float | None = None

@@ -1257,7 +1226,7 @@
  def with_importance_sampling(
  self,
  num_draws: int,
- method: Literal["psis", "psir", "identity", "none"] | None,
+ method: Literal["psis", "psir", "identity"] | None,
  random_seed: int | None = None,
  ) -> Self:
  """perform importance sampling"""
@@ -1388,15 +1357,16 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
  warnings = []

  lbfgs_status_message = {
- LBFGSStatus.MAX_ITER_REACHED: "LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
- LBFGSStatus.INIT_FAILED: "LBFGS failed to initialise. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.",
- LBFGSStatus.DIVERGED: "LBFGS diverged to infinity. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
+ LBFGSStatus.MAX_ITER_REACHED: "MAX_ITER_REACHED: LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
+ LBFGSStatus.INIT_FAILED: "INIT_FAILED: LBFGS failed to initialize. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.",
+ LBFGSStatus.NON_FINITE: "NON_FINITE: LBFGS objective function produced inf or nan at the last iteration. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
+ LBFGSStatus.LOW_UPDATE_PCT: "LOW_UPDATE_PCT: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
+ LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT: "INIT_FAILED_LOW_UPDATE_PCT: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
  }

  path_status_message = {
- PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter is may be too close to the mean posterior and a poor exploration of the parameter space. Consider increasing jitter if this occurence is high relative to the number of paths.",
- PathStatus.INVALID_LOGP: "Invalid logP values occur when a path's logP values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
- PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
+ PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO_ARGMAX_AT_ZERO: ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
+ PathStatus.INVALID_LOGQ: "INVALID_LOGQ: Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
  }

  for lbfgs_status in mpr.lbfgs_status:
@@ -1423,7 +1393,7 @@ def multipath_pathfinder(
  num_elbo_draws: int,
  jitter: float,
  epsilon: float,
- importance_sampling: Literal["psis", "psir", "identity", "none"] | None,
+ importance_sampling: Literal["psis", "psir", "identity"] | None,
  progressbar: bool,
  concurrent: Literal["thread", "process"] | None,
  random_seed: RandomSeed,
@@ -1459,8 +1429,14 @@
  Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
  epsilon: float
  value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
- importance_sampling : str, optional
- importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
+ importance_sampling : str, None, optional
+ Method to apply sampling based on log importance weights (logP - logQ).
+ "psis" : Pareto Smoothed Importance Sampling (default)
+ Recommended for more stable results.
+ "psir" : Pareto Smoothed Importance Resampling
+ Less stable than PSIS.
+ "identity" : Applies log importance weights directly without resampling.
+ None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
  progressbar : bool, optional
  Whether to display a progress bar (default is False). Setting this to True will likely increase the computation time.
  random_seed : RandomSeed, optional
@@ -1482,12 +1458,6 @@
  The result containing samples and other information from the Multi-Path Pathfinder algorithm.
  """

- valid_importance_sampling = ["psis", "psir", "identity", "none", None]
- if importance_sampling is None:
- importance_sampling = "none"
- if importance_sampling.lower() not in valid_importance_sampling:
- raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
-
  *path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1)

  pathfinder_config = PathfinderConfig(
@@ -1521,12 +1491,20 @@
  results = []
  compute_start = time.time()
  try:
- with CustomProgress(
+ desc = f"Paths Complete: {{path_idx}}/{num_paths}"
+ progress = CustomProgress(
+ "[progress.description]{task.description}",
+ BarColumn(),
+ "[progress.percentage]{task.percentage:>3.0f}%",
+ TimeRemainingColumn(),
+ TextColumn("/"),
+ TimeElapsedColumn(),
  console=Console(theme=default_progress_theme),
  disable=not progressbar,
- ) as progress:
- task = progress.add_task("Fitting", total=num_paths)
- for result in generator:
+ )
+ with progress:
+ task = progress.add_task(desc.format(path_idx=0), completed=0, total=num_paths)
+ for path_idx, result in enumerate(generator, start=1):
  try:
  if isinstance(result, Exception):
  raise result
@@ -1552,7 +1530,15 @@
  lbfgs_status=LBFGSStatus.LBFGS_FAILED,
  )
  )
- progress.update(task, advance=1)
+ finally:
+ # TODO: display LBFGS and Path Status in real time
+ progress.update(
+ task,
+ description=desc.format(path_idx=path_idx),
+ completed=path_idx,
+ )
+ # Ensure the progress bar visually reaches 100% and shows 'Completed'
+ progress.update(task, completed=num_paths, description="Completed")
  except (KeyboardInterrupt, StopIteration) as e:
  # if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
  if isinstance(e, StopIteration):
@@ -1602,11 +1588,11 @@ def fit_pathfinder(
  maxiter: int = 1000, # L^max
  ftol: float = 1e-5,
  gtol: float = 1e-8,
- maxls=1000,
+ maxls: int = 1000,
  num_elbo_draws: int = 10, # K
  jitter: float = 2.0,
  epsilon: float = 1e-8,
- importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis",
+ importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
  progressbar: bool = True,
  concurrent: Literal["thread", "process"] | None = None,
  random_seed: RandomSeed | None = None,
@@ -1614,6 +1600,7 @@
  inference_backend: Literal["pymc", "blackjax"] = "pymc",
  pathfinder_kwargs: dict = {},
  compile_kwargs: dict = {},
+ initvals: dict | None = None,
  ) -> az.InferenceData:
  """
  Fit the Pathfinder Variational Inference algorithm.
@@ -1646,8 +1633,15 @@
  Amount of jitter to apply to initial points (default is 2.0). Note that Pathfinder may be highly sensitive to the jitter value. It is recommended to increase num_paths when increasing the jitter value.
  epsilon: float
  value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8).
- importance_sampling : str, optional
- importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N).
+ importance_sampling : str, None, optional
+ Method to apply sampling based on log importance weights (logP - logQ).
+ Options are:
+
+ - "psis" : Pareto Smoothed Importance Sampling (default). Usually most stable.
+ - "psir" : Pareto Smoothed Importance Resampling. Less stable than PSIS.
+ - "identity" : Applies log importance weights directly without resampling.
+ - None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
+
  progressbar : bool, optional
  Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
  random_seed : RandomSeed, optional
@@ -1662,10 +1656,13 @@
  Additional keyword arguments for the Pathfinder algorithm.
  compile_kwargs
  Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
+ initvals: dict | None = None
+ Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted.
+ If None, the model's default initial values are used.

  Returns
  -------
- arviz.InferenceData
+ :class:`~arviz.InferenceData`
  The inference data containing the results of the Pathfinder algorithm.

  References
@@ -1674,6 +1671,23 @@
  """

  model = modelcontext(model)
+
+ if initvals is not None:
+ model = pm.model.fgraph.clone_model(model) # Create a clone of the model
+ for (
+ rv_name,
+ ivals,
+ ) in initvals.items(): # Set the initial values for the variables in the clone
+ model.set_initval(model.named_vars[rv_name], ivals)
+
+ valid_importance_sampling = {"psis", "psir", "identity", None}
+
+ if importance_sampling is not None:
+ importance_sampling = importance_sampling.lower()
+
+ if importance_sampling not in valid_importance_sampling:
+ raise ValueError(f"Invalid importance sampling method: {importance_sampling}")
+
  N = DictToArrayBijection.map(model.initial_point()).data.shape[0]

  if maxcor is None:
@@ -1704,8 +1718,9 @@
  )
  pathfinder_samples = mp_result.samples
  elif inference_backend == "blackjax":
- if find_spec("blackjax") is None:
- raise RuntimeError("Need BlackJAX to use `pathfinder`")
+ import blackjax
+ import jax
+
  if version.parse(blackjax.__version__).major < 1:
  raise ImportError("fit_pathfinder requires blackjax 1.0 or above")

@@ -1743,4 +1758,7 @@
  model=model,
  importance_sampling=importance_sampling,
  )
+
+ idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
+
  return idata
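Taken together, the pathfinder changes above adjust the public fit_pathfinder API: importance_sampling now accepts None instead of the string "none", a new optional initvals argument allows partial initialization from name-to-array pairs, observed data is attached to the returned InferenceData via add_data_to_inferencedata, and blackjax/jax are only imported when the blackjax backend is requested. A hedged usage sketch against the new signature follows; the toy model, the top-level import path, and the num_paths/num_draws values are illustrative assumptions rather than text taken from this diff.

    import numpy as np
    import pymc as pm
    import pymc_extras as pmx  # assumes fit_pathfinder is re-exported at the package top level

    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 1.0)
        sigma = pm.HalfNormal("sigma", 1.0)
        pm.Normal("y", mu, sigma, observed=np.random.default_rng(0).normal(1.0, 2.0, size=200))

        idata = pmx.fit_pathfinder(
            num_paths=8,
            num_draws=1000,
            jitter=4.0,
            initvals={"mu": np.array(0.5)},   # partial initialization, new in this release
            importance_sampling=None,         # raw draws shaped (num_paths, num_draws_per_path, N)
            random_seed=123,
        )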
@@ -28,7 +28,6 @@ from pymc_extras.statespace.filters import (
  )
  from pymc_extras.statespace.filters.distributions import (
  LinearGaussianStateSpace,
- MvNormalSVD,
  SequenceMvNormal,
  )
  from pymc_extras.statespace.filters.utilities import stabilize
@@ -707,7 +706,7 @@ class PyMCStateSpace:
  with pymc_model:
  for param_name in self.param_names:
  param = getattr(pymc_model, param_name, None)
- if param:
+ if param is not None:
  found_params.append(param.name)

  missing_params = list(set(self.param_names) - set(found_params))
@@ -746,7 +745,7 @@
  with pymc_model:
  for data_name in data_names:
  data = getattr(pymc_model, data_name, None)
- if data:
+ if data is not None:
  found_data.append(data.name)

  missing_data = list(set(data_names) - set(found_data))
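The two statespace hunks above replace truthiness checks with explicit is-not-None comparisons. The distinction matters because the objects returned by getattr(pymc_model, name, None) are model variables or data containers whose truth value is either ambiguous or misleading. A short illustration of the Python-level pitfall with plain NumPy arrays (an assumption used only to demonstrate the behavior, not pymc-extras code):

    import numpy as np

    present_but_falsy = np.array(0.0)   # a found object whose bool() happens to be False
    missing = None

    # Truthiness check: wrongly treats the zero-valued array as "not found".
    print(bool(present_but_falsy))       # False
    # Multi-element arrays are worse: bool() raises "truth value ... is ambiguous".

    # Explicit presence check, as in the diff above: only None counts as missing.
    print(present_but_falsy is not None) # True
    print(missing is not None)           # False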
@@ -2233,7 +2232,9 @@
  if shock_trajectory is None:
  shock_trajectory = pt.zeros((n_steps, self.k_posdef))
  if Q is not None:
- init_shock = MvNormalSVD("initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM])
+ init_shock = pm.MvNormal(
+ "initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method="svd"
+ )
  else:
  init_shock = pm.Deterministic(
  "initial_shock",